提交 0dd06485 编写于 作者: Y Yanbo Liang 提交者: Joseph K. Bradley

[SPARK-13615][ML] GeneralizedLinearRegression supports save/load

## What changes were proposed in this pull request?
```GeneralizedLinearRegression``` supports ```save/load```.
cc mengxr
## How was this patch tested?
unit test.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #11465 from yanboliang/spark-13615.
上级 cad29a40
......@@ -18,6 +18,7 @@
package org.apache.spark.ml.regression
import breeze.stats.distributions.{Gaussian => GD}
import org.apache.hadoop.fs.Path
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.{Experimental, Since}
......@@ -26,7 +27,7 @@ import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.optim._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
......@@ -106,7 +107,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
@Since("2.0.0")
class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String)
extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel]
with GeneralizedLinearRegressionBase with Logging {
with GeneralizedLinearRegressionBase with DefaultParamsWritable with Logging {
import GeneralizedLinearRegression._
......@@ -236,10 +237,13 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
}
@Since("2.0.0")
private[ml] object GeneralizedLinearRegression {
object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLinearRegression] {
@Since("2.0.0")
override def load(path: String): GeneralizedLinearRegression = super.load(path)
/** Set of family and link pairs that GeneralizedLinearRegression supports. */
lazy val supportedFamilyAndLinkPairs = Set(
private[ml] lazy val supportedFamilyAndLinkPairs = Set(
Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse,
Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog,
Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt,
......@@ -247,12 +251,12 @@ private[ml] object GeneralizedLinearRegression {
)
/** Set of family names that GeneralizedLinearRegression supports. */
lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
/** Set of link names that GeneralizedLinearRegression supports. */
lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
val epsilon: Double = 1E-16
private[ml] val epsilon: Double = 1E-16
/**
* Wrapper of family and link combination used in the model.
......@@ -552,7 +556,7 @@ class GeneralizedLinearRegressionModel private[ml] (
@Since("2.0.0") val coefficients: Vector,
@Since("2.0.0") val intercept: Double)
extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
with GeneralizedLinearRegressionBase {
with GeneralizedLinearRegressionBase with MLWritable {
import GeneralizedLinearRegression._
......@@ -574,4 +578,58 @@ class GeneralizedLinearRegressionModel private[ml] (
copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
.setParent(parent)
}
@Since("2.0.0")
override def write: MLWriter =
new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this)
}
@Since("2.0.0")
object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegressionModel] {
@Since("2.0.0")
override def read: MLReader[GeneralizedLinearRegressionModel] =
new GeneralizedLinearRegressionModelReader
@Since("2.0.0")
override def load(path: String): GeneralizedLinearRegressionModel = super.load(path)
/** [[MLWriter]] instance for [[GeneralizedLinearRegressionModel]] */
private[GeneralizedLinearRegressionModel]
class GeneralizedLinearRegressionModelWriter(instance: GeneralizedLinearRegressionModel)
extends MLWriter with Logging {
private case class Data(intercept: Double, coefficients: Vector)
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: intercept, coefficients
val data = Data(instance.intercept, instance.coefficients)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
private class GeneralizedLinearRegressionModelReader
extends MLReader[GeneralizedLinearRegressionModel] {
/** Checked against metadata when loading model */
private val className = classOf[GeneralizedLinearRegressionModel].getName
override def load(path: String): GeneralizedLinearRegressionModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath)
.select("intercept", "coefficients").head()
val intercept = data.getDouble(0)
val coefficients = data.getAs[Vector](1)
val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}
......@@ -21,7 +21,7 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors}
import org.apache.spark.mllib.random._
......@@ -30,7 +30,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
class GeneralizedLinearRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
private val seed: Int = 42
@transient var datasetGaussianIdentity: DataFrame = _
......@@ -464,10 +465,37 @@ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSpark
}
}
}
test("read/write") {
def checkModelData(
model: GeneralizedLinearRegressionModel,
model2: GeneralizedLinearRegressionModel): Unit = {
assert(model.intercept === model2.intercept)
assert(model.coefficients.toArray === model2.coefficients.toArray)
}
val glr = new GeneralizedLinearRegression()
testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
}
}
object GeneralizedLinearRegressionSuite {
/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allParamSettings: Map[String, Any] = Map(
"family" -> "poisson",
"link" -> "log",
"fitIntercept" -> true,
"maxIter" -> 2, // intentionally small
"tol" -> 0.8,
"regParam" -> 0.01,
"predictionCol" -> "myPrediction")
def generateGeneralizedLinearRegressionInput(
intercept: Double,
coefficients: Array[Double],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册