未验证 提交 7ecbcc08 编写于 作者: C chenxujun 提交者: GitHub

【Hackathon No.62】digamma, dirichlet算子FP16/BF16单测完善 (#52604)

* Add digamma, dirichlet tests

* Fix code
上级 eeb4d165
......@@ -15,9 +15,16 @@
#include "paddle/phi/kernels/digamma_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/digamma_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
digamma_grad, GPU, ALL_LAYOUT, phi::DigammaGradKernel, float, double) {}
PD_REGISTER_KERNEL(digamma_grad,
GPU,
ALL_LAYOUT,
phi::DigammaGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -15,10 +15,17 @@
#include "paddle/phi/kernels/digamma_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/digamma_kernel_impl.h"
PD_REGISTER_KERNEL(
digamma, GPU, ALL_LAYOUT, phi::DigammaKernel, float, double) {}
PD_REGISTER_KERNEL(digamma,
GPU,
ALL_LAYOUT,
phi::DigammaKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -112,5 +112,11 @@ struct DirichletSampler<GPUContext, T> {
};
} // namespace phi
PD_REGISTER_KERNEL(
dirichlet, GPU, ALL_LAYOUT, phi::Dirichletkernel, float, double) {}
PD_REGISTER_KERNEL(dirichlet,
GPU,
ALL_LAYOUT,
phi::Dirichletkernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -16,6 +16,7 @@
#include <unsupported/Eigen/SpecialFunctions>
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
......@@ -27,7 +28,11 @@ struct DigammaGradFunctor {
: dout_(dout), x_(x), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
output_[idx] = dout_[idx] * Eigen::numext::polygamma(T(1), x_[idx]);
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
const MPType mp_dout = static_cast<MPType>(dout_[idx]);
const MPType mp_x = static_cast<MPType>(x_[idx]);
output_[idx] =
static_cast<T>(mp_dout * Eigen::numext::polygamma(MPType(1), mp_x));
}
private:
......
......@@ -16,6 +16,7 @@
#include <unsupported/Eigen/SpecialFunctions>
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
......@@ -27,7 +28,9 @@ struct DigammaFunctor {
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
output_[idx] = Eigen::numext::digamma(input_[idx]);
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
const MPType mp_input = static_cast<MPType>(input_[idx]);
output_[idx] = static_cast<T>(Eigen::numext::digamma(mp_input));
}
private:
......
......@@ -16,6 +16,7 @@
#include <cmath>
#include <random>
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/dirichlet_kernel.h"
// ROCM hcc doesn't work well with using std:: in kernel functions
......@@ -47,7 +48,10 @@ template <typename ScalarT, typename SamplerT>
struct BaseSampler {
SamplerT sampler_;
HOSTDEVICE BaseSampler(const SamplerT& sampler) : sampler_(sampler) {}
HOSTDEVICE ScalarT sample() { return sampler_(); }
HOSTDEVICE ScalarT sample() {
// Sometimes convert float to float16/bfloat16
return static_cast<ScalarT>(sampler_());
}
};
// `sample_gamma` is d from Numpy's distributions.c, and add support for
......@@ -83,33 +87,40 @@ HOSTDEVICE ScalarT
sample_gamma(ScalarT alpha,
BaseSampler<AccscalarT, UniformSamplerT> standard_uniform,
BaseSampler<AccscalarT, NormalSamplerT> standard_normal) {
AccscalarT scale = 1.0f;
using MPTypeScalar = typename phi::dtype::MPTypeTrait<ScalarT>::Type;
using MPTypeAccscalar = typename phi::dtype::MPTypeTrait<AccscalarT>::Type;
MPTypeAccscalar mp_scale = static_cast<MPTypeAccscalar>(1.0f);
MPTypeScalar mp_alpha = static_cast<MPTypeScalar>(alpha);
// Boost alpha for higher acceptance probability.
if (alpha < 1.0f) {
if (alpha == 0.f) return 0.f;
scale *= COMPAT_POW(1 - standard_uniform.sample(), 1.0f / alpha);
alpha += 1.0f;
if (mp_alpha < 1.0f) {
if (mp_alpha == 0.f) return static_cast<ScalarT>(0.f);
MPTypeAccscalar mp_sample =
static_cast<MPTypeAccscalar>(standard_uniform.sample());
mp_scale *= COMPAT_POW(1 - mp_sample, 1.0f / mp_alpha);
mp_alpha += 1.0f;
}
// This implements the acceptance-rejection method of Marsaglia and Tsang
// (2000)
// doi:10.1145/358407.358414
const AccscalarT d = alpha - 1.0f / 3.0f;
const AccscalarT c = 1.0f / COMPAT_SQRT(9.0f * d);
const MPTypeAccscalar d = mp_alpha - 1.0f / 3.0f;
const MPTypeAccscalar c = 1.0f / COMPAT_SQRT(9.0f * d);
for (;;) {
AccscalarT x, y;
MPTypeAccscalar x, y;
do {
x = standard_normal.sample();
x = static_cast<MPTypeAccscalar>(standard_normal.sample());
y = 1.0f + c * x;
} while (y <= 0);
const AccscalarT v = y * y * y;
const AccscalarT u = 1 - standard_uniform.sample();
const AccscalarT xx = x * x;
const MPTypeAccscalar v = y * y * y;
const MPTypeAccscalar u =
1 - static_cast<MPTypeAccscalar>(standard_uniform.sample());
const MPTypeAccscalar xx = x * x;
if (u < 1.0f - 0.0331f * xx * xx)
return static_cast<ScalarT>(scale * d * v);
return static_cast<ScalarT>(mp_scale * d * v);
if (COMPAT_LOG(u) < 0.5f * xx + d * (1.0f - v + COMPAT_LOG(v)))
return static_cast<ScalarT>(scale * d * v);
return static_cast<ScalarT>(mp_scale * d * v);
}
}
......
......@@ -164,7 +164,10 @@ def _dirichlet(concentration, name=None):
else:
op_type = 'dirichlet'
check_variable_and_dtype(
concentration, 'concentration', ['float32', 'float64'], op_type
concentration,
'concentration',
['float16', 'float32', 'float64', 'uint16'],
op_type,
)
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(
......
......@@ -20,7 +20,15 @@ import scipy.stats
import paddle
sys.path.append("../")
from eager_op_test import OpTest
import unittest
from eager_op_test import (
OpTest,
convert_float_to_uint16,
convert_uint16_to_float,
)
from paddle.fluid import core
paddle.enable_static()
......@@ -52,3 +60,89 @@ class TestDirichletOp(OpTest):
)[0],
0.01,
)
class TestDirichletFP16Op(OpTest):
# Because dirichlet random sample have not gradient, we skip gradient check.
no_need_check_grad = True
def setUp(self):
self.op_type = "dirichlet"
self.alpha = np.array((1.0, 2.0))
self.sample_shape = (100000, 2)
self.dtype = np.float16
self.inputs = {
'Alpha': np.broadcast_to(self.alpha, self.sample_shape).astype(
self.dtype
)
}
self.attrs = {}
self.outputs = {'Out': np.zeros(self.sample_shape).astype(self.dtype)}
def test_check_output(self):
self.check_output_customized(self._hypothesis_testing)
def _hypothesis_testing(self, outs):
self.assertEqual(outs[0].shape, self.sample_shape)
self.assertTrue(np.all(outs[0] > 0.0))
self.assertLess(
scipy.stats.kstest(
outs[0][:, 0],
# scipy dirichlet have not cdf, use beta to replace it.
scipy.stats.beta(a=self.alpha[0], b=self.alpha[1]).cdf,
)[0],
0.01,
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestDirichletBF16Op(OpTest):
# Because dirichlet random sample have not gradient, we skip gradient check.
no_need_check_grad = True
def setUp(self):
self.op_type = "dirichlet"
self.alpha = np.array((1.0, 2.0))
self.sample_shape = (10000, 2)
self.dtype = np.uint16
self.np_dtype = np.float32
self.inputs = {
'Alpha': np.broadcast_to(self.alpha, self.sample_shape).astype(
self.np_dtype
)
}
self.attrs = {}
self.outputs = {
'Out': np.zeros(self.sample_shape).astype(self.np_dtype)
}
self.inputs['Alpha'] = convert_float_to_uint16(self.inputs['Alpha'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place_customized(
self._hypothesis_testing, place=core.CUDAPlace(0)
)
def _hypothesis_testing(self, outs):
outs = convert_uint16_to_float(outs)
self.assertEqual(outs[0].shape, self.sample_shape)
self.assertTrue(np.all(outs[0] > 0.0))
self.assertLess(
scipy.stats.kstest(
outs[0][:, 0],
# scipy dirichlet have not cdf, use beta to replace it.
scipy.stats.beta(a=self.alpha[0], b=self.alpha[1]).cdf,
)[0],
0.3, # The bfloat16 test difference is below 0.3
)
if __name__ == '__main__':
unittest.main()
......@@ -15,11 +15,12 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
from scipy.special import psi
import paddle
from paddle import fluid, static
from paddle.fluid import core
class TestDigammaOp(OpTest):
......@@ -55,6 +56,43 @@ class TestDigammaOpFp32(TestDigammaOp):
self.check_grad(['X'], 'Out')
class TestDigammaFP16Op(TestDigammaOp):
def init_dtype_type(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestDigammaBF16Op(OpTest):
def setUp(self):
# switch to static
paddle.enable_static()
self.op_type = 'digamma'
self.python_api = paddle.digamma
self.init_dtype_type()
shape = (5, 32)
data = np.random.random(shape).astype(self.np_dtype) + 1
self.inputs = {'X': convert_float_to_uint16(data)}
result = np.ones(shape).astype(self.np_dtype)
result = psi(data)
self.outputs = {'Out': convert_float_to_uint16(result)}
def init_dtype_type(self):
self.dtype = np.uint16
self.np_dtype = np.float32
def test_check_output(self):
# bfloat16 needs to set the parameter place
self.check_output_with_place(core.CUDAPlace(0))
def test_check_grad_normal(self):
self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Out')
class TestDigammaAPI(unittest.TestCase):
def setUp(self):
# switch to static
......
......@@ -4041,7 +4041,9 @@ def digamma(x, name=None):
if in_dygraph_mode():
return _C_ops.digamma(x)
else:
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'digamma')
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'digamma'
)
helper = LayerHelper('digamma', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(type='digamma', inputs={'X': x}, outputs={'Out': out})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册