提交 ad3bcda5 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Add bootstrap and Chebyshev options for determing if best split is dominating

the second best split in finished_nodes_op.
Change: 136640840
上级 83099f2b
......@@ -20,6 +20,7 @@
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/work_sharder.h"
......@@ -29,9 +30,14 @@ using shape_inference::Dimension;
using shape_inference::InferenceContext;
using shape_inference::Shape;
using std::placeholders::_1;
using std::placeholders::_2;
using tensorforest::CheckTensorBounds;
using tensorforest::Sum;
using tensorforest::BestSplitDominatesClassification;
using tensorforest::BestSplitDominatesClassificationBootstrap;
using tensorforest::BestSplitDominatesClassificationChebyshev;
using tensorforest::BestSplitDominatesClassificationHoeffding;
using tensorforest::BestSplitDominatesRegression;
namespace {
......@@ -39,16 +45,14 @@ namespace {
struct EvaluateParams {
Tensor leaves;
Tensor node_to_accumulator;
Tensor split_sums;
Tensor split_squares;
Tensor accumulator_sums;
Tensor accumulator_squares;
Tensor birth_epochs;
int current_epoch;
float dominate_fraction;
int32 num_split_after_samples;
int32 min_split_samples;
bool regression;
bool need_random;
int64 random_seed;
std::function<bool(int, random::SimplePhilox*)> dominate_method;
};
void Evaluate(const EvaluateParams& params, mutex* mutex, int32 start,
......@@ -65,6 +69,13 @@ void Evaluate(const EvaluateParams& params, mutex* mutex, int32 start,
std::vector<int32> finished_leaves;
std::vector<int32> stale;
std::unique_ptr<random::SimplePhilox> simple_philox;
if (params.need_random) {
random::PhiloxRandom rnd_gen(params.random_seed);
simple_philox.reset(new random::SimplePhilox(&rnd_gen));
}
for (int32 i = start; i < end; i++) {
const int32 leaf = internal::SubtleMustCopy(leaves(i));
if (leaf == -1) {
......@@ -103,17 +114,7 @@ void Evaluate(const EvaluateParams& params, mutex* mutex, int32 start,
continue;
}
bool finished = false;
if (params.regression) {
finished = BestSplitDominatesRegression(
params.accumulator_sums, params.accumulator_squares,
params.split_sums, params.split_squares, accumulator);
} else {
finished = BestSplitDominatesClassification(
params.accumulator_sums, params.split_sums, accumulator,
params.dominate_fraction);
}
bool finished = params.dominate_method(accumulator, simple_philox.get());
if (finished) {
finished_leaves.push_back(leaf);
}
......@@ -130,6 +131,12 @@ REGISTER_OP("FinishedNodes")
.Attr("num_split_after_samples: int")
.Attr("min_split_samples: int")
.Attr("dominate_fraction: float = 0.99")
// TODO(thomaswc): Test out bootstrap on several datasets, confirm it
// works well, make it the default.
.Attr(
"dominate_method:"
" {'none', 'hoeffding', 'bootstrap', 'chebyshev'} = 'hoeffding'")
.Attr("random_seed: int = 0")
.Input("leaves: int32")
.Input("node_to_accumulator: int32")
.Input("split_sums: float")
......@@ -194,6 +201,9 @@ class FinishedNodes : public OpKernel {
"min_split_samples", &min_split_samples_));
OP_REQUIRES_OK(context, context->GetAttr(
"dominate_fraction", &dominate_fraction_));
OP_REQUIRES_OK(context,
context->GetAttr("dominate_method", &dominate_method_));
OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_));
}
void Compute(OpKernelContext* context) override {
......@@ -249,15 +259,45 @@ class FinishedNodes : public OpKernel {
EvaluateParams params;
params.leaves = leaf_tensor;
params.node_to_accumulator = node_to_accumulator;
params.split_sums = split_sums;
params.split_squares = split_squares;
params.accumulator_sums = accumulator_sums;
params.birth_epochs = birth_epochs;
params.current_epoch = epoch;
params.dominate_fraction = dominate_fraction_;
params.min_split_samples = min_split_samples_;
params.num_split_after_samples = num_split_after_samples_;
params.regression = regression_;
params.need_random = false;
if (regression_) {
params.dominate_method =
std::bind(&BestSplitDominatesRegression, accumulator_sums,
accumulator_squares, split_sums, split_squares, _1);
} else {
if (dominate_method_ == "none") {
params.dominate_method = [](int, random::SimplePhilox*) {
return false;
};
} else if (dominate_method_ == "hoeffding") {
params.dominate_method =
std::bind(&BestSplitDominatesClassificationHoeffding,
accumulator_sums, split_sums, _1, dominate_fraction_);
} else if (dominate_method_ == "chebyshev") {
params.dominate_method =
std::bind(&BestSplitDominatesClassificationChebyshev,
accumulator_sums, split_sums, _1, dominate_fraction_);
} else if (dominate_method_ == "bootstrap") {
params.need_random = true;
params.random_seed = random_seed_;
if (params.random_seed == 0) {
params.random_seed = static_cast<uint64>(std::clock());
}
params.dominate_method =
std::bind(&BestSplitDominatesClassificationBootstrap,
accumulator_sums, split_sums, _1, dominate_fraction_, _2);
} else {
LOG(FATAL) << "Unknown dominate method " << dominate_method_;
}
}
std::vector<int32> finished_leaves;
std::vector<int32> stale;
......@@ -300,6 +340,8 @@ class FinishedNodes : public OpKernel {
int32 num_split_after_samples_;
int32 min_split_samples_;
float dominate_fraction_;
string dominate_method_;
int32 random_seed_;
};
REGISTER_KERNEL_BUILDER(Name("FinishedNodes").Device(DEVICE_CPU),
......
......@@ -13,7 +13,10 @@
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
#include <algorithm>
#include <cfloat>
#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
......@@ -203,10 +206,104 @@ bool BestSplitDominatesRegression(
return false;
}
bool BestSplitDominatesClassification(
const Tensor& total_counts,
const Tensor& split_counts, int32 accumulator,
float dominate_fraction) {
// We return the Gini Impurity of the bootstrap sample as an int rather
// than a float, so that we can more easily check for ties.
int BootstrapGini(int n, int s, const random::DistributionSampler& ds,
random::SimplePhilox* rand) {
std::vector<int> counts(s, 0);
for (int i = 0; i < n; i++) {
int j = ds.Sample(rand);
counts[j] += 1;
}
int g = 0;
for (int j = 0; j < s; j++) {
g += counts[j] * counts[j];
}
// The true gini is 1 + (-g) / n^2
return -g;
}
// Populate *weights with the smoothed per-class frequencies needed to
// initialize a DistributionSampler. Returns the total number of samples
// seen by this accumulator.
int MakeBootstrapWeights(const Tensor& total_counts, const Tensor& split_counts,
int32 accumulator, int index,
std::vector<float>* weights) {
const int32 num_classes =
static_cast<int32>(split_counts.shape().dim_size(2)) - 1;
auto tc = total_counts.tensor<float, 2>();
auto lc = split_counts.tensor<float, 3>();
int n = tc(accumulator, 0);
float denom = static_cast<float>(n) + static_cast<float>(num_classes);
weights->resize(num_classes * 2);
for (int i = 0; i < num_classes; i++) {
// Use the Laplace smoothed per-class probabilities when generating the
// bootstrap samples.
float left_count = lc(accumulator, index, i + 1);
(*weights)[i] = (left_count + 1.0) / denom;
float right_count = tc(accumulator, i + 1) - left_count;
(*weights)[num_classes + i] = (right_count + 1.0) / denom;
}
return n;
}
bool BestSplitDominatesClassificationBootstrap(const Tensor& total_counts,
const Tensor& split_counts,
int32 accumulator,
float dominate_fraction,
random::SimplePhilox* rand) {
float best_score;
float second_best_score;
int best_feature_index;
int second_best_index;
GetTwoBestClassification(total_counts, split_counts, accumulator, &best_score,
&best_feature_index, &second_best_score,
&second_best_index);
std::vector<float> weights1;
int n1 = MakeBootstrapWeights(total_counts, split_counts, accumulator,
best_feature_index, &weights1);
random::DistributionSampler ds1(weights1);
std::vector<float> weights2;
int n2 = MakeBootstrapWeights(total_counts, split_counts, accumulator,
second_best_index, &weights2);
random::DistributionSampler ds2(weights2);
const int32 num_classes =
static_cast<int32>(split_counts.shape().dim_size(2)) - 1;
float p = 1.0 - dominate_fraction;
if (p <= 0 || p > 1.0) {
LOG(FATAL) << "Invalid dominate fraction " << dominate_fraction;
}
int bootstrap_samples = 1;
while (p < 1.0) {
bootstrap_samples += 1;
p = p * 2;
}
for (int i = 0; i < bootstrap_samples; i++) {
int g1 = BootstrapGini(n1, 2 * num_classes, ds1, rand);
int g2 = BootstrapGini(n2, 2 * num_classes, ds2, rand);
if (g2 <= g1) {
return false;
}
}
return true;
}
bool BestSplitDominatesClassificationHoeffding(const Tensor& total_counts,
const Tensor& split_counts,
int32 accumulator,
float dominate_fraction) {
float best_score;
float second_best_score;
int best_feature_index;
......@@ -226,8 +323,6 @@ bool BestSplitDominatesClassification(
// Each term in the Gini impurity can range from 0 to 0.5 * 0.5.
float range = 0.25 * static_cast<float>(num_classes) * n;
// TODO(thomaswc): The hoeffding bound is actually only valid for linear
// functions, which the Gini impurity is not. Come up with a better bound!
float hoeffding_bound =
range * sqrt(log(1.0 / (1.0 - dominate_fraction)) / (2.0 * n));
......@@ -238,6 +333,199 @@ bool BestSplitDominatesClassification(
return (second_best_score - best_score) > hoeffding_bound;
}
double DirichletCovarianceTrace(const Tensor& total_counts,
const Tensor& split_counts, int32 accumulator,
int index) {
const int32 num_classes =
static_cast<int32>(split_counts.shape().dim_size(2)) - 1;
auto tc = total_counts.tensor<float, 2>();
auto lc = split_counts.tensor<float, 3>();
double leftc = 0.0;
double leftc2 = 0.0;
double rightc = 0.0;
double rightc2 = 0.0;
for (int i = 1; i <= num_classes; i++) {
double l = lc(accumulator, index, i) + 1.0;
leftc += l;
leftc2 += l * l;
double r = tc(accumulator, i) - lc(accumulator, index, i) + 1.0;
rightc += r;
rightc2 += r * r;
}
double left_trace = (1.0 - leftc2 / (leftc * leftc)) / (leftc + 1.0);
double right_trace = (1.0 - rightc2 / (rightc * rightc)) / (rightc + 1.0);
return left_trace + right_trace;
}
void getDirichletMean(const Tensor& total_counts, const Tensor& split_counts,
int32 accumulator, int index, std::vector<float>* mu) {
const int32 num_classes =
static_cast<int32>(split_counts.shape().dim_size(2)) - 1;
mu->resize(num_classes * 2);
auto tc = total_counts.tensor<float, 2>();
auto lc = split_counts.tensor<float, 3>();
double total = tc(accumulator, 0);
for (int i = 0; i < num_classes; i++) {
double l = lc(accumulator, index, i + 1);
mu->at(i) = (l + 1.0) / (total + num_classes);
double r = tc(accumulator, i) - l;
mu->at(i + num_classes) = (r + 1.) / (total + num_classes);
}
}
// Given lambda3, returns the distance from (mu1, mu2) to the surface.
double getDistanceFromLambda3(double lambda3, const std::vector<float>& mu1,
const std::vector<float>& mu2) {
if (fabs(lambda3) == 1.0) {
return 0.0;
}
int n = mu1.size();
double lambda1 = -2.0 * lambda3 / n;
double lambda2 = 2.0 * lambda3 / n;
// From below,
// x = (lambda_1 1 + 2 mu1) / (2 - 2 lambda_3)
// y = (lambda_2 1 + 2 mu2) / (2 + 2 lambda_3)
double dist = 0.0;
for (int i = 0; i < mu1.size(); i++) {
double diff = (lambda1 + 2.0 * mu1[i]) / (2.0 - 2.0 * lambda3) - mu1[i];
dist += diff * diff;
diff = (lambda2 + 2.0 * mu2[i]) / (2.0 + 2.0 * lambda3) - mu2[i];
dist += diff * diff;
}
return dist;
}
// Returns the distance between (mu1, mu2) and (x, y), where (x, y) is the
// nearest point that lies on the surface defined by
// {x dot 1 = 1, y dot 1 = 1, x dot x - y dot y = 0}.
double getChebyshevEpsilon(const std::vector<float>& mu1,
const std::vector<float>& mu2) {
// Math time!!
// We are trying to minimize d = |mu1 - x|^2 + |mu2 - y|^2 over the surface.
// Using Langrange multipliers, we get
// partial d / partial x = -2 mu1 + 2 x = lambda_1 1 + 2 lambda_3 x
// partial d / partial y = -2 mu2 + 2 y = lambda_2 1 - 2 lambda_3 y
// or
// x = (lambda_1 1 + 2 mu1) / (2 - 2 lambda_3)
// y = (lambda_2 1 + 2 mu2) / (2 + 2 lambda_3)
// which implies
// 2 - 2 lambda_3 = lambda_1 1 dot 1 + 2 mu1 dot 1
// 2 + 2 lambda_3 = lambda_2 1 dot 1 + 2 mu2 dot 1
// |lambda_1 1 + 2 mu1|^2 (2 + 2 lambda_3)^2 =
// |lambda_2 1 + 2 mu2|^2 (2 - 2 lambda_3)^2
// So solving for the lambda's and using the fact that
// mu1 dot 1 = 1 and mu2 dot 1 = 1,
// lambda_1 = -2 lambda_3 / (1 dot 1)
// lambda_2 = 2 lambda_3 / (1 dot 1)
// and (letting n = 1 dot 1)
// | - lambda_3 1 + n mu1 |^2 (1 + lambda_3)^2 =
// | lambda_3 1 + n mu2 |^2 (1 - lambda_3)^2
// or
// (lambda_3^2 n - 2 n lambda_3 + n^2 mu1 dot mu1)(1 + lambda_3)^2 =
// (lambda_3^2 n + 2 n lambda_3 + n^2 mu2 dot mu2)(1 - lambda_3)^2
// or
// (lambda_3^2 - 2 lambda_3 + n mu1 dot mu1)(1 + 2 lambda_3 + lambda_3^2) =
// (lambda_3^2 + 2 lambda_3 + n mu2 dot mu2)(1 - 2 lambda_3 + lambda_3^2)
// or
// lambda_3^2 - 2 lambda_3 + n mu1 dot mu1
// + 2 lambda_3^3 - 2 lambda_3^2 + 2n lambda_3 mu1 dot mu1
// + lambda_3^4 - 2 lambda_3^3 + n lambda_3^2 mu1 dot mu1
// =
// lambda_3^2 + 2 lambda_3 + n mu2 dot mu2
// - 2 lambda_3^3 -4 lambda_3^2 - 2n lambda_3 mu2 dot mu2
// + lambda_3^4 + 2 lambda_3^3 + n lambda_3^2 mu2 dot mu2
// or
// - 2 lambda_3 + n mu1 dot mu1
// - 2 lambda_3^2 + 2n lambda_3 mu1 dot mu1
// + n lambda_3^2 mu1 dot mu1
// =
// + 2 lambda_3 + n mu2 dot mu2
// -4 lambda_3^2 - 2n lambda_3 mu2 dot mu2
// + n lambda_3^2 mu2 dot mu2
// or
// lambda_3^2 (2 + n mu1 dot mu1 + n mu2 dot mu2)
// + lambda_3 (2n mu1 dot mu1 + 2n mu2 dot mu2 - 4)
// + n mu1 dot mu1 - n mu2 dot mu2 = 0
// which can be solved using the quadratic formula.
int n = mu1.size();
double len1 = 0.0;
for (float m : mu1) {
len1 += m * m;
}
double len2 = 0.0;
for (float m : mu2) {
len2 += m * m;
}
double a = 2 + n * (len1 + len2);
double b = 2 * n * (len1 + len2) - 4;
double c = n * (len1 - len2);
double discrim = b * b - 4 * a * c;
if (discrim < 0.0) {
LOG(WARNING) << "Negative discriminant " << discrim;
return 0.0;
}
double sdiscrim = sqrt(discrim);
// TODO(thomaswc): Analyze whetever one of these is always closer.
double v1 = (-b + sdiscrim) / (2 * a);
double v2 = (-b - sdiscrim) / (2 * a);
double dist1 = getDistanceFromLambda3(v1, mu1, mu2);
double dist2 = getDistanceFromLambda3(v2, mu1, mu2);
return std::min(dist1, dist2);
}
bool BestSplitDominatesClassificationChebyshev(const Tensor& total_counts,
const Tensor& split_counts,
int32 accumulator,
float dominate_fraction) {
float best_score;
float second_best_score;
int best_feature_index;
int second_best_index;
VLOG(1) << "BSDC for accumulator " << accumulator;
GetTwoBestClassification(total_counts, split_counts, accumulator, &best_score,
&best_feature_index, &second_best_score,
&second_best_index);
VLOG(1) << "Best score = " << best_score;
VLOG(1) << "2nd best score = " << second_best_score;
const int32 num_classes =
static_cast<int32>(split_counts.shape().dim_size(2)) - 1;
const float n = total_counts.Slice(accumulator, accumulator + 1)
.unaligned_flat<float>()(0);
VLOG(1) << "num_classes = " << num_classes;
VLOG(1) << "n = " << n;
double trace = DirichletCovarianceTrace(total_counts, split_counts,
accumulator, best_feature_index) +
DirichletCovarianceTrace(total_counts, split_counts,
accumulator, second_best_index);
std::vector<float> mu1;
getDirichletMean(total_counts, split_counts, accumulator, best_feature_index,
&mu1);
std::vector<float> mu2;
getDirichletMean(total_counts, split_counts, accumulator, second_best_index,
&mu2);
double epsilon = getChebyshevEpsilon(mu1, mu2);
if (epsilon == 0.0) {
return false;
}
double dirichlet_bound = 1.0 - trace / (epsilon * epsilon);
return dirichlet_bound > dominate_fraction;
}
bool DecideNode(const Tensor& point, int32 feature, float bias,
DataColumnTypes type) {
const auto p = point.unaligned_flat<float>();
......
......@@ -20,6 +20,7 @@
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
......@@ -99,12 +100,27 @@ bool BestSplitDominatesRegression(
const Tensor& split_sums, const Tensor& split_squares,
int32 accumulator);
// Performs booststrap_samples bootstrap samples of the best split's class
// counts and the second best splits's class counts, and returns true if at
// least dominate_fraction of the time, the former has a better (lower)
// Gini impurity. Does not take over ownership of *rand.
bool BestSplitDominatesClassificationBootstrap(
const Tensor& total_counts, const Tensor& split_counts, int32 accumulator,
float dominate_fraction, tensorflow::random::SimplePhilox* rand);
// Returns true if the best split's Gini impurity is sufficiently smaller than
// that of the next best split.
bool BestSplitDominatesClassification(
const Tensor& total_counts,
const Tensor& split_counts, int32 accumulator,
float dominate_fraction);
// that of the next best split, as measured by the Hoeffding Tree bound.
bool BestSplitDominatesClassificationHoeffding(const Tensor& total_counts,
const Tensor& split_counts,
int32 accumulator,
float dominate_fraction);
// Returns true if the best split's Gini impurity is sufficiently smaller than
// that of the next best split, as measured by a Chebyshev bound.
bool BestSplitDominatesClassificationChebyshev(const Tensor& total_counts,
const Tensor& split_counts,
int32 accumulator,
float dominate_fraction);
// Initializes everything in the given tensor to the given value.
template <typename T>
......
......@@ -103,16 +103,63 @@ class FinishedNodesTest(test_util.TensorFlowTestCase):
self.assertAllEqual([], finished.eval())
self.assertAllEqual([], stale.eval())
def testEarlyDominates(self):
def testEarlyDominatesHoeffding(self):
with self.test_session():
finished, stale = self.ops.finished_nodes(
self.leaves, self.node_map, self.split_sums,
self.split_squares, self.accumulator_sums, self.accumulator_squares,
self.birth_epochs, self.current_epoch,
regression=False, num_split_after_samples=10, min_split_samples=5)
self.leaves,
self.node_map,
self.split_sums,
self.split_squares,
self.accumulator_sums,
self.accumulator_squares,
self.birth_epochs,
self.current_epoch,
dominate_method='hoeffding',
regression=False,
num_split_after_samples=10,
min_split_samples=5)
self.assertAllEqual([4], finished.eval())
self.assertAllEqual([], stale.eval())
def testEarlyDominatesBootstrap(self):
with self.test_session():
finished, stale = self.ops.finished_nodes(
self.leaves,
self.node_map,
self.split_sums,
self.split_squares,
self.accumulator_sums,
self.accumulator_squares,
self.birth_epochs,
self.current_epoch,
dominate_method='bootstrap',
regression=False,
num_split_after_samples=10,
min_split_samples=5)
self.assertAllEqual([4], finished.eval())
self.assertAllEqual([], stale.eval())
def testEarlyDominatesChebyshev(self):
with self.test_session():
finished, stale = self.ops.finished_nodes(
self.leaves,
self.node_map,
self.split_sums,
self.split_squares,
self.accumulator_sums,
self.accumulator_squares,
self.birth_epochs,
self.current_epoch,
dominate_method='chebyshev',
regression=False,
num_split_after_samples=10,
min_split_samples=5)
self.assertAllEqual([4], finished.eval())
self.assertAllEqual([], stale.eval())
if __name__ == '__main__':
googletest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册