提交 255b56f9 编写于 作者: D DB Tsai 提交者: Xiangrui Meng

[SPARK-2479][MLlib] Comparing floating-point numbers using relative error in UnitTests

Floating point math is not exact, and most floating-point numbers end up being slightly imprecise due to rounding errors.

Simple values like 0.1 cannot be precisely represented using binary floating point numbers, and the limited precision of floating point numbers means that slight changes in the order of operations or the precision of intermediates can change the result.

That means that comparing two floats to see if they are equal is usually not what we want. As long as this imprecision stays small, it can usually be ignored.

Based on discussion in the community, we have implemented two different APIs for relative tolerance, and absolute tolerance. It makes sense that test writers should know which one they need depending on their circumstances.

Developers also need to explicitly specify the eps, and there is no default value which will sometimes cause confusion.

When comparing against zero using relative tolerance, a exception will be raised to warn users that it's meaningless.

For relative tolerance, users can now write

    assert(23.1 ~== 23.52 relTol 0.02)
    assert(23.1 ~== 22.74 relTol 0.02)
    assert(23.1 ~= 23.52 relTol 0.02)
    assert(23.1 ~= 22.74 relTol 0.02)
    assert(!(23.1 !~= 23.52 relTol 0.02))
    assert(!(23.1 !~= 22.74 relTol 0.02))

    // This will throw exception with the following message.
    // "Did not expect 23.1 and 23.52 to be within 0.02 using relative tolerance."
    assert(23.1 !~== 23.52 relTol 0.02)

    // "Expected 23.1 and 22.34 to be within 0.02 using relative tolerance."
    assert(23.1 ~== 22.34 relTol 0.02)

For absolute error,

    assert(17.8 ~== 17.99 absTol 0.2)
    assert(17.8 ~== 17.61 absTol 0.2)
    assert(17.8 ~= 17.99 absTol 0.2)
    assert(17.8 ~= 17.61 absTol 0.2)
    assert(!(17.8 !~= 17.99 absTol 0.2))
    assert(!(17.8 !~= 17.61 absTol 0.2))

    // This will throw exception with the following message.
    // "Did not expect 17.8 and 17.99 to be within 0.2 using absolute error."
    assert(17.8 !~== 17.99 absTol 0.2)

    // "Expected 17.8 and 17.59 to be within 0.2 using absolute error."
    assert(17.8 ~== 17.59 absTol 0.2)

Authors:
  DB Tsai <dbtsaialpinenow.com>
  Marek Kolodziej <marekalpinenow.com>

Author: DB Tsai <dbtsai@alpinenow.com>

Closes #1425 from dbtsai/SPARK-2479_comparing_floating_point and squashes the following commits:

8c7cbcc [DB Tsai] Alpine Data Labs
上级 2b8d89e3
......@@ -26,6 +26,7 @@ import org.scalatest.Matchers
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
object LogisticRegressionSuite {
......@@ -81,9 +82,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
val model = lr.run(testRDD)
// Test the weights
val weight0 = model.weights(0)
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
assert(model.weights(0) ~== -1.52 relTol 0.01)
assert(model.intercept ~== 2.00 relTol 0.01)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
......@@ -113,9 +113,9 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
val model = lr.run(testRDD, initialWeights)
val weight0 = model.weights(0)
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
// Test the weights
assert(model.weights(0) ~== -1.50 relTol 0.01)
assert(model.intercept ~== 1.97 relTol 0.01)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
......
......@@ -21,8 +21,9 @@ import scala.util.Random
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
class KMeansSuite extends FunSuite with LocalSparkContext {
......@@ -41,26 +42,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
// centered at the mean of the points
var model = KMeans.train(data, k = 1, maxIterations = 1)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 2)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(
data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
}
test("no distinct points") {
......@@ -104,26 +105,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
var model = KMeans.train(data, k = 1, maxIterations = 1)
assert(model.clusterCenters.size === 1)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 2)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
}
test("single cluster with sparse data") {
......@@ -149,31 +150,39 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0)))
var model = KMeans.train(data, k = 1, maxIterations = 1)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 2)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head === center)
assert(model.clusterCenters.head ~== center absTol 1E-5)
data.unpersist()
}
test("k-means|| initialization") {
case class VectorWithCompare(x: Vector) extends Ordered[VectorWithCompare] {
@Override def compare(that: VectorWithCompare): Int = {
if(this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) >
that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) -1 else 1
}
}
val points = Seq(
Vectors.dense(1.0, 2.0, 6.0),
Vectors.dense(1.0, 3.0, 0.0),
......@@ -188,15 +197,19 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
// unselected point as long as it hasn't yet selected all of them
var model = KMeans.train(rdd, k = 5, maxIterations = 1)
assert(Set(model.clusterCenters: _*) === Set(points: _*))
assert(model.clusterCenters.sortBy(VectorWithCompare(_))
.zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5))
// Iterations of Lloyd's should not change the answer either
model = KMeans.train(rdd, k = 5, maxIterations = 10)
assert(Set(model.clusterCenters: _*) === Set(points: _*))
assert(model.clusterCenters.sortBy(VectorWithCompare(_))
.zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5))
// Neither should more runs
model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5)
assert(Set(model.clusterCenters: _*) === Set(points: _*))
assert(model.clusterCenters.sortBy(VectorWithCompare(_))
.zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5))
}
test("two clusters") {
......
......@@ -20,27 +20,28 @@ package org.apache.spark.mllib.evaluation
import org.scalatest.FunSuite
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.util.TestingUtils._
class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {
test("auc computation") {
val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
val auc = 4.0
assert(AreaUnderCurve.of(curve) === auc)
assert(AreaUnderCurve.of(curve) ~== auc absTol 1E-5)
val rddCurve = sc.parallelize(curve, 2)
assert(AreaUnderCurve.of(rddCurve) == auc)
assert(AreaUnderCurve.of(rddCurve) ~== auc absTol 1E-5)
}
test("auc of an empty curve") {
val curve = Seq.empty[(Double, Double)]
assert(AreaUnderCurve.of(curve) === 0.0)
assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5)
val rddCurve = sc.parallelize(curve, 2)
assert(AreaUnderCurve.of(rddCurve) === 0.0)
assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5)
}
test("auc of a curve with a single point") {
val curve = Seq((1.0, 1.0))
assert(AreaUnderCurve.of(curve) === 0.0)
assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5)
val rddCurve = sc.parallelize(curve, 2)
assert(AreaUnderCurve.of(rddCurve) === 0.0)
assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5)
}
}
......@@ -20,25 +20,14 @@ package org.apache.spark.mllib.evaluation
import org.scalatest.FunSuite
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.util.TestingUtils.DoubleWithAlmostEquals
import org.apache.spark.mllib.util.TestingUtils._
class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
// TODO: move utility functions to TestingUtils.
def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
def elementsAlmostEqual(actual: Seq[Double], expected: Seq[Double]): Boolean = {
actual.zip(expected).forall { case (x1, x2) =>
x1.almostEquals(x2)
}
}
def elementsAlmostEqual(
actual: Seq[(Double, Double)],
expected: Seq[(Double, Double)])(implicit dummy: DummyImplicit): Boolean = {
actual.zip(expected).forall { case ((x1, y1), (x2, y2)) =>
x1.almostEquals(x2) && y1.almostEquals(y2)
}
}
def cond2(x: ((Double, Double), (Double, Double))): Boolean =
(x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)
test("binary evaluation metrics") {
val scoreAndLabels = sc.parallelize(
......@@ -57,16 +46,17 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
val pr = recall.zip(precision)
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) }
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
assert(elementsAlmostEqual(metrics.thresholds().collect(), threshold))
assert(elementsAlmostEqual(metrics.roc().collect(), rocCurve))
assert(metrics.areaUnderROC().almostEquals(AreaUnderCurve.of(rocCurve)))
assert(elementsAlmostEqual(metrics.pr().collect(), prCurve))
assert(metrics.areaUnderPR().almostEquals(AreaUnderCurve.of(prCurve)))
assert(elementsAlmostEqual(metrics.fMeasureByThreshold().collect(), threshold.zip(f1)))
assert(elementsAlmostEqual(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2)))
assert(elementsAlmostEqual(metrics.precisionByThreshold().collect(), threshold.zip(precision)))
assert(elementsAlmostEqual(metrics.recallByThreshold().collect(), threshold.zip(recall)))
assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
assert(metrics.pr().collect().zip(prCurve).forall(cond2))
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
}
}
......@@ -25,6 +25,7 @@ import org.scalatest.{FunSuite, Matchers}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
object GradientDescentSuite {
......@@ -126,19 +127,14 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers
val (newWeights1, loss1) = GradientDescent.runMiniBatchSGD(
dataRDD, gradient, updater, 1, 1, regParam1, 1.0, initialWeightsWithIntercept)
def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = {
math.abs(x - y) / (math.abs(y) + 1e-15) < tol
}
assert(compareDouble(
loss1(0),
loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) +
math.pow(initialWeightsWithIntercept(1), 2)) / 2),
assert(
loss1(0) ~= (loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) +
math.pow(initialWeightsWithIntercept(1), 2)) / 2) absTol 1E-5,
"""For non-zero weights, the regVal should be \frac{1}{2}\sum_i w_i^2.""")
assert(
compareDouble(newWeights1(0) , newWeights0(0) - initialWeightsWithIntercept(0)) &&
compareDouble(newWeights1(1) , newWeights0(1) - initialWeightsWithIntercept(1)),
(newWeights1(0) ~= (newWeights0(0) - initialWeightsWithIntercept(0)) absTol 1E-5) &&
(newWeights1(1) ~= (newWeights0(1) - initialWeightsWithIntercept(1)) absTol 1E-5),
"The different between newWeights with/without regularization " +
"should be initialWeightsWithIntercept.")
}
......
......@@ -24,6 +24,7 @@ import org.scalatest.{FunSuite, Matchers}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
......@@ -49,10 +50,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
lazy val dataRDD = sc.parallelize(data, 2).cache()
def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = {
math.abs(x - y) / (math.abs(y) + 1e-15) < tol
}
test("LBFGS loss should be decreasing and match the result of Gradient Descent.") {
val regParam = 0
......@@ -126,15 +123,15 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
miniBatchFrac,
initialWeightsWithIntercept)
assert(compareDouble(lossGD(0), lossLBFGS(0)),
assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5,
"The first losses of LBFGS and GD should be the same.")
// The 2% difference here is based on observation, but is not theoretically guaranteed.
assert(compareDouble(lossGD.last, lossLBFGS.last, 0.02),
assert(lossGD.last ~= lossLBFGS.last relTol 0.02,
"The last losses of LBFGS and GD should be within 2% difference.")
assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) &&
compareDouble(weightLBFGS(1), weightGD(1), 0.02),
assert(
(weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02),
"The weight differences between LBFGS and GD should be within 2%.")
}
......@@ -226,8 +223,8 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
initialWeightsWithIntercept)
// for class LBFGS and the optimize method, we only look at the weights
assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) &&
compareDouble(weightLBFGS(1), weightGD(1), 0.02),
assert(
(weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02),
"The weight differences between LBFGS and GD should be within 2%.")
}
}
......
......@@ -21,7 +21,9 @@ import scala.util.Random
import org.scalatest.FunSuite
import org.jblas.{DoubleMatrix, SimpleBlas, NativeBlas}
import org.jblas.{DoubleMatrix, SimpleBlas}
import org.apache.spark.mllib.util.TestingUtils._
class NNLSSuite extends FunSuite {
/** Generate an NNLS problem whose optimal solution is the all-ones vector. */
......@@ -73,7 +75,7 @@ class NNLSSuite extends FunSuite {
val ws = NNLS.createWorkspace(n)
val x = NNLS.solve(ata, atb, ws)
for (i <- 0 until n) {
assert(Math.abs(x(i) - goodx(i)) < 1e-3)
assert(x(i) ~== goodx(i) absTol 1E-3)
assert(x(i) >= 0)
}
}
......
......@@ -89,15 +89,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
.add(Vectors.dense(-1.0, 0.0, 6.0))
.add(Vectors.dense(3.0, -3.0, 0.0))
assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch")
assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch")
assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch")
assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch")
assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch")
assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch")
assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch")
assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch")
assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch")
assert(summarizer.count === 2)
}
......@@ -107,15 +107,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
.add(Vectors.sparse(3, Seq((0, -1.0), (2, 6.0))))
.add(Vectors.sparse(3, Seq((0, 3.0), (1, -3.0))))
assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch")
assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch")
assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch")
assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch")
assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch")
assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch")
assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch")
assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch")
assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch")
assert(summarizer.count === 2)
}
......@@ -129,17 +129,17 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
.add(Vectors.dense(1.7, -0.6, 0.0))
.add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0))))
assert(summarizer.mean.almostEquals(
Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch")
assert(summarizer.mean ~==
Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch")
assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch")
assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch")
assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch")
assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch")
assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch")
assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer.variance.almostEquals(
Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch")
assert(summarizer.variance ~==
Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch")
assert(summarizer.count === 6)
}
......@@ -157,17 +157,17 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
val summarizer = summarizer1.merge(summarizer2)
assert(summarizer.mean.almostEquals(
Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch")
assert(summarizer.mean ~==
Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch")
assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch")
assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch")
assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch")
assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch")
assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch")
assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer.variance.almostEquals(
Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch")
assert(summarizer.variance ~==
Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch")
assert(summarizer.count === 6)
}
......@@ -186,24 +186,24 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
val summarizer3 = (new MultivariateOnlineSummarizer).merge(new MultivariateOnlineSummarizer)
assert(summarizer3.count === 0)
assert(summarizer1.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch")
assert(summarizer1.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch")
assert(summarizer2.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch")
assert(summarizer2.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch")
assert(summarizer1.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch")
assert(summarizer1.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch")
assert(summarizer2.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch")
assert(summarizer2.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch")
assert(summarizer1.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch")
assert(summarizer1.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch")
assert(summarizer2.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch")
assert(summarizer2.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch")
assert(summarizer1.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch")
assert(summarizer1.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer2.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch")
assert(summarizer2.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer1.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch")
assert(summarizer1.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch")
assert(summarizer2.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch")
assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch")
}
}
......@@ -18,28 +18,155 @@
package org.apache.spark.mllib.util
import org.apache.spark.mllib.linalg.Vector
import org.scalatest.exceptions.TestFailedException
object TestingUtils {
val ABS_TOL_MSG = " using absolute tolerance"
val REL_TOL_MSG = " using relative tolerance"
/**
* Private helper function for comparing two values using relative tolerance.
* Note that if x or y is extremely close to zero, i.e., smaller than Double.MinPositiveValue,
* the relative tolerance is meaningless, so the exception will be raised to warn users.
*/
private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
val absX = math.abs(x)
val absY = math.abs(y)
val diff = math.abs(x - y)
if (x == y) {
true
} else if (absX < Double.MinPositiveValue || absY < Double.MinPositiveValue) {
throw new TestFailedException(
s"$x or $y is extremely close to zero, so the relative tolerance is meaningless.", 0)
} else {
diff < eps * math.min(absX, absY)
}
}
/**
* Private helper function for comparing two values using absolute tolerance.
*/
private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
math.abs(x - y) < eps
}
case class CompareDoubleRightSide(
fun: (Double, Double, Double) => Boolean, y: Double, eps: Double, method: String)
/**
* Implicit class for comparing two double values using relative tolerance or absolute tolerance.
*/
implicit class DoubleWithAlmostEquals(val x: Double) {
// An improved version of AlmostEquals would always divide by the larger number.
// This will avoid the problem of diving by zero.
def almostEquals(y: Double, epsilon: Double = 1E-10): Boolean = {
if(x == y) {
true
} else if(math.abs(x) > math.abs(y)) {
math.abs(x - y) / math.abs(x) < epsilon
} else {
math.abs(x - y) / math.abs(y) < epsilon
/**
* When the difference of two values are within eps, returns true; otherwise, returns false.
*/
def ~=(r: CompareDoubleRightSide): Boolean = r.fun(x, r.y, r.eps)
/**
* When the difference of two values are within eps, returns false; otherwise, returns true.
*/
def !~=(r: CompareDoubleRightSide): Boolean = !r.fun(x, r.y, r.eps)
/**
* Throws exception when the difference of two values are NOT within eps;
* otherwise, returns true.
*/
def ~==(r: CompareDoubleRightSide): Boolean = {
if (!r.fun(x, r.y, r.eps)) {
throw new TestFailedException(
s"Expected $x and ${r.y} to be within ${r.eps}${r.method}.", 0)
}
true
}
/**
* Throws exception when the difference of two values are within eps; otherwise, returns true.
*/
def !~==(r: CompareDoubleRightSide): Boolean = {
if (r.fun(x, r.y, r.eps)) {
throw new TestFailedException(
s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method}.", 0)
}
true
}
/**
* Comparison using absolute tolerance.
*/
def absTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(AbsoluteErrorComparison,
x, eps, ABS_TOL_MSG)
/**
* Comparison using relative tolerance.
*/
def relTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(RelativeErrorComparison,
x, eps, REL_TOL_MSG)
override def toString = x.toString
}
case class CompareVectorRightSide(
fun: (Vector, Vector, Double) => Boolean, y: Vector, eps: Double, method: String)
/**
* Implicit class for comparing two vectors using relative tolerance or absolute tolerance.
*/
implicit class VectorWithAlmostEquals(val x: Vector) {
def almostEquals(y: Vector, epsilon: Double = 1E-10): Boolean = {
x.toArray.corresponds(y.toArray) {
_.almostEquals(_, epsilon)
/**
* When the difference of two vectors are within eps, returns true; otherwise, returns false.
*/
def ~=(r: CompareVectorRightSide): Boolean = r.fun(x, r.y, r.eps)
/**
* When the difference of two vectors are within eps, returns false; otherwise, returns true.
*/
def !~=(r: CompareVectorRightSide): Boolean = !r.fun(x, r.y, r.eps)
/**
* Throws exception when the difference of two vectors are NOT within eps;
* otherwise, returns true.
*/
def ~==(r: CompareVectorRightSide): Boolean = {
if (!r.fun(x, r.y, r.eps)) {
throw new TestFailedException(
s"Expected $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0)
}
true
}
/**
* Throws exception when the difference of two vectors are within eps; otherwise, returns true.
*/
def !~==(r: CompareVectorRightSide): Boolean = {
if (r.fun(x, r.y, r.eps)) {
throw new TestFailedException(
s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0)
}
true
}
/**
* Comparison using absolute tolerance.
*/
def absTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide(
(x: Vector, y: Vector, eps: Double) => {
x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps)
}, x, eps, ABS_TOL_MSG)
/**
* Comparison using relative tolerance. Note that comparing against sparse vector
* with elements having value of zero will raise exception because it involves with
* comparing against zero.
*/
def relTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide(
(x: Vector, y: Vector, eps: Double) => {
x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps)
}, x, eps, REL_TOL_MSG)
override def toString = x.toString
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.util
import org.apache.spark.mllib.linalg.Vectors
import org.scalatest.FunSuite
import org.apache.spark.mllib.util.TestingUtils._
import org.scalatest.exceptions.TestFailedException
class TestingUtilsSuite extends FunSuite {
test("Comparing doubles using relative error.") {
assert(23.1 ~== 23.52 relTol 0.02)
assert(23.1 ~== 22.74 relTol 0.02)
assert(23.1 ~= 23.52 relTol 0.02)
assert(23.1 ~= 22.74 relTol 0.02)
assert(!(23.1 !~= 23.52 relTol 0.02))
assert(!(23.1 !~= 22.74 relTol 0.02))
// Should throw exception with message when test fails.
intercept[TestFailedException](23.1 !~== 23.52 relTol 0.02)
intercept[TestFailedException](23.1 !~== 22.74 relTol 0.02)
intercept[TestFailedException](23.1 ~== 23.63 relTol 0.02)
intercept[TestFailedException](23.1 ~== 22.34 relTol 0.02)
assert(23.1 !~== 23.63 relTol 0.02)
assert(23.1 !~== 22.34 relTol 0.02)
assert(23.1 !~= 23.63 relTol 0.02)
assert(23.1 !~= 22.34 relTol 0.02)
assert(!(23.1 ~= 23.63 relTol 0.02))
assert(!(23.1 ~= 22.34 relTol 0.02))
// Comparing against zero should fail the test and throw exception with message
// saying that the relative error is meaningless in this situation.
intercept[TestFailedException](0.1 ~== 0.0 relTol 0.032)
intercept[TestFailedException](0.1 ~= 0.0 relTol 0.032)
intercept[TestFailedException](0.1 !~== 0.0 relTol 0.032)
intercept[TestFailedException](0.1 !~= 0.0 relTol 0.032)
intercept[TestFailedException](0.0 ~== 0.1 relTol 0.032)
intercept[TestFailedException](0.0 ~= 0.1 relTol 0.032)
intercept[TestFailedException](0.0 !~== 0.1 relTol 0.032)
intercept[TestFailedException](0.0 !~= 0.1 relTol 0.032)
// Comparisons of numbers very close to zero.
assert(10 * Double.MinPositiveValue ~== 9.5 * Double.MinPositiveValue relTol 0.01)
assert(10 * Double.MinPositiveValue !~== 11 * Double.MinPositiveValue relTol 0.01)
assert(-Double.MinPositiveValue ~== 1.18 * -Double.MinPositiveValue relTol 0.012)
assert(-Double.MinPositiveValue ~== 1.38 * -Double.MinPositiveValue relTol 0.012)
}
test("Comparing doubles using absolute error.") {
assert(17.8 ~== 17.99 absTol 0.2)
assert(17.8 ~== 17.61 absTol 0.2)
assert(17.8 ~= 17.99 absTol 0.2)
assert(17.8 ~= 17.61 absTol 0.2)
assert(!(17.8 !~= 17.99 absTol 0.2))
assert(!(17.8 !~= 17.61 absTol 0.2))
// Should throw exception with message when test fails.
intercept[TestFailedException](17.8 !~== 17.99 absTol 0.2)
intercept[TestFailedException](17.8 !~== 17.61 absTol 0.2)
intercept[TestFailedException](17.8 ~== 18.01 absTol 0.2)
intercept[TestFailedException](17.8 ~== 17.59 absTol 0.2)
assert(17.8 !~== 18.01 absTol 0.2)
assert(17.8 !~== 17.59 absTol 0.2)
assert(17.8 !~= 18.01 absTol 0.2)
assert(17.8 !~= 17.59 absTol 0.2)
assert(!(17.8 ~= 18.01 absTol 0.2))
assert(!(17.8 ~= 17.59 absTol 0.2))
// Comparisons of numbers very close to zero, and both side of zeros
assert(Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
assert(Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
assert(-Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
assert(Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
}
test("Comparing vectors using relative error.") {
//Comparisons of two dense vectors
assert(Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01)
assert(Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01)
assert(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01)
assert(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01)
assert(!(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01))
assert(!(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01))
// Should throw exception with message when test fails.
intercept[TestFailedException](
Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01)
intercept[TestFailedException](
Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01)
// Comparing against zero should fail the test and throw exception with message
// saying that the relative error is meaningless in this situation.
intercept[TestFailedException](
Vectors.dense(Array(3.1, 0.01)) ~== Vectors.dense(Array(3.13, 0.0)) relTol 0.01)
intercept[TestFailedException](
Vectors.dense(Array(3.1, 0.01)) ~== Vectors.sparse(2, Array(0), Array(3.13)) relTol 0.01)
// Comparisons of two sparse vectors
assert(Vectors.dense(Array(3.1, 3.5)) ~==
Vectors.sparse(2, Array(0, 1), Array(3.130, 3.534)) relTol 0.01)
assert(Vectors.dense(Array(3.1, 3.5)) !~==
Vectors.sparse(2, Array(0, 1), Array(3.135, 3.534)) relTol 0.01)
}
test("Comparing vectors using absolute error.") {
//Comparisons of two dense vectors
assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~==
Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6)
assert(Vectors.dense(Array(3.1, 3.5, 0.0)) !~==
Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6)
assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~=
Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6)
assert(Vectors.dense(Array(3.1, 3.5, 0.0)) !~=
Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6)
assert(!(Vectors.dense(Array(3.1, 3.5, 0.0)) !~=
Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6))
assert(!(Vectors.dense(Array(3.1, 3.5, 0.0)) ~=
Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6))
// Should throw exception with message when test fails.
intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) !~==
Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6)
intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) ~==
Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6)
// Comparisons of two sparse vectors
assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) ~==
Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-8, 2.4 + 1E-7)) absTol 1E-6)
assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-8, 2.4 + 1E-7)) ~==
Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6)
assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~==
Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-3, 2.4)) absTol 1E-6)
assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-3, 2.4)) !~==
Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6)
// Comparisons of a dense vector and a sparse vector
assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) ~==
Vectors.dense(Array(3.1 + 1E-8, 0, 2.4 + 1E-7)) absTol 1E-6)
assert(Vectors.dense(Array(3.1 + 1E-8, 0, 2.4 + 1E-7)) ~==
Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6)
assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~==
Vectors.dense(Array(3.1, 1E-3, 2.4)) absTol 1E-6)
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册