提交 331e724d 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Fix a flaky test by computing half-precision gamma samples with float...

Fix a flaky test by computing half-precision gamma samples with float precision intermediate calculations.
Change: 125776264
上级 10ff074b
...@@ -266,6 +266,21 @@ class RandomUniformIntOp : public OpKernel { ...@@ -266,6 +266,21 @@ class RandomUniformIntOp : public OpKernel {
GuardedPhiloxRandom generator_; GuardedPhiloxRandom generator_;
}; };
namespace {

// Trait that selects the arithmetic type used for the intermediate
// calculations when drawing Gamma samples of element type T.
//
// Primary template: the computation is carried out in T itself.
template <typename T>
struct GammaComputeType {
  using ComputeType = T;
};

// Half-precision specialization: intermediate calculations are performed in
// float precision rather than in Eigen::half.
template <>
struct GammaComputeType<Eigen::half> {
  using ComputeType = float;
};

}  // namespace
// Samples from one or more gamma distributions. // Samples from one or more gamma distributions.
template <typename T> template <typename T>
class RandomGammaOp : public OpKernel { class RandomGammaOp : public OpKernel {
...@@ -307,8 +322,8 @@ class RandomGammaOp : public OpKernel { ...@@ -307,8 +322,8 @@ class RandomGammaOp : public OpKernel {
using random::PhiloxRandom; using random::PhiloxRandom;
typedef random::NormalDistribution<PhiloxRandom, T> Normal; typedef random::NormalDistribution<PhiloxRandom, CT> Normal;
typedef random::UniformDistribution<PhiloxRandom, T> Uniform; typedef random::UniformDistribution<PhiloxRandom, CT> Uniform;
// Each attempt is 95+% successful, and requires 1-2 normal + 1 uniform // Each attempt is 95+% successful, and requires 1-2 normal + 1 uniform
static constexpr int kReservedSamplesPerOutput = 256; static constexpr int kReservedSamplesPerOutput = 256;
...@@ -348,16 +363,16 @@ class RandomGammaOp : public OpKernel { ...@@ -348,16 +363,16 @@ class RandomGammaOp : public OpKernel {
int64 alpha_idx = output_idx / num_samples; int64 alpha_idx = output_idx / num_samples;
// Several calculations can be done on a per-alpha basis. // Several calculations can be done on a per-alpha basis.
const T alpha = alpha_flat[alpha_idx]; const CT alpha = CT(alpha_flat[alpha_idx]);
// For alpha<1, we add one to d=alpha-1/3, and multiply the final result // For alpha<1, we add one to d=alpha-1/3, and multiply the final result
// by uniform()^(1/alpha) // by uniform()^(1/alpha)
bool alpha_less_than_one = alpha < T(1); bool alpha_less_than_one = alpha < CT(1);
static const T kMinusOneThird = T(-1) / 3; static const CT kMinusOneThird = CT(-1) / 3;
static const T kTwoThirds = T(2) / 3; static const CT kTwoThirds = CT(2) / 3;
const T d = alpha + (alpha_less_than_one ? kTwoThirds : kMinusOneThird); const CT d =
static const T kOneThird = T(1) / 3; alpha + (alpha_less_than_one ? kTwoThirds : kMinusOneThird);
using Eigen::numext::sqrt; static const CT kOneThird = CT(1) / 3;
const T c = kOneThird / sqrt(d); const CT c = kOneThird / sqrt(d);
// Instead of +alpha_idx for each sample, we offset the pointer once. // Instead of +alpha_idx for each sample, we offset the pointer once.
auto samples_alpha_offset = samples_flat + alpha_idx; auto samples_alpha_offset = samples_flat + alpha_idx;
...@@ -382,9 +397,9 @@ class RandomGammaOp : public OpKernel { ...@@ -382,9 +397,9 @@ class RandomGammaOp : public OpKernel {
norm_result = normal(&gen); norm_result = normal(&gen);
} }
norm_remaining--; norm_remaining--;
const T x = norm_result[norm_remaining]; const CT x = norm_result[norm_remaining];
T v = T(1) + c * x; CT v = CT(1) + c * x;
if (v <= T(0)) { if (v <= CT(0)) {
continue; continue;
} }
v = v * v * v; v = v * v * v;
...@@ -393,16 +408,16 @@ class RandomGammaOp : public OpKernel { ...@@ -393,16 +408,16 @@ class RandomGammaOp : public OpKernel {
uniform_result = uniform(&gen); uniform_result = uniform(&gen);
} }
uniform_remaining--; uniform_remaining--;
T u = uniform_result[uniform_remaining]; CT u = uniform_result[uniform_remaining];
using Eigen::numext::log; using Eigen::numext::log;
// The first option in the if is a "squeeze" short-circuit to dodge // The first option in the if is a "squeeze" short-circuit to dodge
// the two logs. Magic constant sourced from the paper linked above. // the two logs. Magic constant sourced from the paper linked above.
// Upward of .91 of the area covered by the log inequality is // Upward of .91 of the area covered by the log inequality is
// covered by the squeeze as well (larger coverage for smaller // covered by the squeeze as well (larger coverage for smaller
// values of alpha). // values of alpha).
if ((u < T(1) - T(0.0331) * (x * x) * (x * x)) || if ((u < CT(1) - CT(0.0331) * (x * x) * (x * x)) ||
(log(u) < T(0.5) * x * x + d * (T(1) - v + log(v)))) { (log(u) < CT(0.5) * x * x + d * (CT(1) - v + log(v)))) {
T res = d * v; CT res = d * v;
if (alpha_less_than_one) { if (alpha_less_than_one) {
if (uniform_remaining == 0) { if (uniform_remaining == 0) {
uniform_remaining = Uniform::kResultElementCount; uniform_remaining = Uniform::kResultElementCount;
...@@ -410,9 +425,9 @@ class RandomGammaOp : public OpKernel { ...@@ -410,9 +425,9 @@ class RandomGammaOp : public OpKernel {
} }
uniform_remaining--; uniform_remaining--;
using Eigen::numext::pow; using Eigen::numext::pow;
res *= pow(uniform_result[uniform_remaining], T(1) / alpha); res *= pow(uniform_result[uniform_remaining], CT(1) / alpha);
} }
samples_alpha_offset[sample_idx * num_alphas] = res; samples_alpha_offset[sample_idx * num_alphas] = T(res);
break; break;
} }
} }
...@@ -433,6 +448,7 @@ class RandomGammaOp : public OpKernel { ...@@ -433,6 +448,7 @@ class RandomGammaOp : public OpKernel {
} }
private: private:
typedef typename GammaComputeType<T>::ComputeType CT;
GuardedPhiloxRandom generator_; GuardedPhiloxRandom generator_;
TF_DISALLOW_COPY_AND_ASSIGN(RandomGammaOp); TF_DISALLOW_COPY_AND_ASSIGN(RandomGammaOp);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册