提交 331e724d 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Fix a flaky test by computing half-precision gamma samples with float...

Fix a flaky test by computing half-precision gamma samples with float precision intermediate calculations.
Change: 125776264
上级 10ff074b
......@@ -266,6 +266,21 @@ class RandomUniformIntOp : public OpKernel {
GuardedPhiloxRandom generator_;
};
namespace {
// We will compute half-precision Gamma samples with float precision
// intermediate calculations.
template <typename T>
struct GammaComputeType {
typedef T ComputeType;
};
template <>
struct GammaComputeType<Eigen::half> {
typedef float ComputeType;
};
} // namespace
// Samples from one or more gamma distributions.
template <typename T>
class RandomGammaOp : public OpKernel {
......@@ -307,8 +322,8 @@ class RandomGammaOp : public OpKernel {
using random::PhiloxRandom;
typedef random::NormalDistribution<PhiloxRandom, T> Normal;
typedef random::UniformDistribution<PhiloxRandom, T> Uniform;
typedef random::NormalDistribution<PhiloxRandom, CT> Normal;
typedef random::UniformDistribution<PhiloxRandom, CT> Uniform;
// Each attempt is 95+% successful, and requires 1-2 normal + 1 uniform
static constexpr int kReservedSamplesPerOutput = 256;
......@@ -348,16 +363,16 @@ class RandomGammaOp : public OpKernel {
int64 alpha_idx = output_idx / num_samples;
// Several calculations can be done on a per-alpha basis.
const T alpha = alpha_flat[alpha_idx];
const CT alpha = CT(alpha_flat[alpha_idx]);
// For alpha<1, we add one to d=alpha-1/3, and multiply the final result
// by uniform()^(1/alpha)
bool alpha_less_than_one = alpha < T(1);
static const T kMinusOneThird = T(-1) / 3;
static const T kTwoThirds = T(2) / 3;
const T d = alpha + (alpha_less_than_one ? kTwoThirds : kMinusOneThird);
static const T kOneThird = T(1) / 3;
using Eigen::numext::sqrt;
const T c = kOneThird / sqrt(d);
bool alpha_less_than_one = alpha < CT(1);
static const CT kMinusOneThird = CT(-1) / 3;
static const CT kTwoThirds = CT(2) / 3;
const CT d =
alpha + (alpha_less_than_one ? kTwoThirds : kMinusOneThird);
static const CT kOneThird = CT(1) / 3;
const CT c = kOneThird / sqrt(d);
// Instead of +alpha_idx for each sample, we offset the pointer once.
auto samples_alpha_offset = samples_flat + alpha_idx;
......@@ -382,9 +397,9 @@ class RandomGammaOp : public OpKernel {
norm_result = normal(&gen);
}
norm_remaining--;
const T x = norm_result[norm_remaining];
T v = T(1) + c * x;
if (v <= T(0)) {
const CT x = norm_result[norm_remaining];
CT v = CT(1) + c * x;
if (v <= CT(0)) {
continue;
}
v = v * v * v;
......@@ -393,16 +408,16 @@ class RandomGammaOp : public OpKernel {
uniform_result = uniform(&gen);
}
uniform_remaining--;
T u = uniform_result[uniform_remaining];
CT u = uniform_result[uniform_remaining];
using Eigen::numext::log;
// The first option in the if is a "squeeze" short-circuit to dodge
// the two logs. Magic constant sourced from the paper linked above.
// Upward of .91 of the area covered by the log inequality is
// covered by the squeeze as well (larger coverage for smaller
// values of alpha).
if ((u < T(1) - T(0.0331) * (x * x) * (x * x)) ||
(log(u) < T(0.5) * x * x + d * (T(1) - v + log(v)))) {
T res = d * v;
if ((u < CT(1) - CT(0.0331) * (x * x) * (x * x)) ||
(log(u) < CT(0.5) * x * x + d * (CT(1) - v + log(v)))) {
CT res = d * v;
if (alpha_less_than_one) {
if (uniform_remaining == 0) {
uniform_remaining = Uniform::kResultElementCount;
......@@ -410,9 +425,9 @@ class RandomGammaOp : public OpKernel {
}
uniform_remaining--;
using Eigen::numext::pow;
res *= pow(uniform_result[uniform_remaining], T(1) / alpha);
res *= pow(uniform_result[uniform_remaining], CT(1) / alpha);
}
samples_alpha_offset[sample_idx * num_alphas] = res;
samples_alpha_offset[sample_idx * num_alphas] = T(res);
break;
}
}
......@@ -433,6 +448,7 @@ class RandomGammaOp : public OpKernel {
}
private:
typedef typename GammaComputeType<T>::ComputeType CT;
GuardedPhiloxRandom generator_;
TF_DISALLOW_COPY_AND_ASSIGN(RandomGammaOp);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册