提交 7fe9273f 编写于 作者: M mbalassi 提交者: Aljoscha Krettek

Added LinearRegression scala example. Removed old BGD example.

上级 dbc680f2
......@@ -16,7 +16,6 @@
* limitations under the License.
*/
package org.apache.flink.example.java.ml.util;
import org.apache.flink.api.java.DataSet;
......@@ -24,45 +23,50 @@ import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.example.java.ml.LinearRegression.Data;
import org.apache.flink.example.java.ml.LinearRegression.Params;
import java.util.LinkedList;
import java.util.List;
/**
* Provides the default data sets used for the Linear Regression example program.
* The default data sets are used, if no parameters are given to the program.
*
* Provides the default data sets used for the Linear Regression example
* program. The default data sets are used, if no parameters are given to the
* program.
*/
public class LinearRegressionData{
public class LinearRegressionData {
// We have the data as object arrays so that we can also generate Scala Data
// Sources from it.
public static final Object[][] PARAMS = new Object[][] { new Object[] {
0.0, 0.0 } };
public static DataSet<Params> getDefaultParamsDataSet(ExecutionEnvironment env){
public static final Object[][] DATA = new Object[][] {
new Object[] { 0.5, 1.0 }, new Object[] { 1.0, 2.0 },
new Object[] { 2.0, 4.0 }, new Object[] { 3.0, 6.0 },
new Object[] { 4.0, 8.0 }, new Object[] { 5.0, 10.0 },
new Object[] { 6.0, 12.0 }, new Object[] { 7.0, 14.0 },
new Object[] { 8.0, 16.0 }, new Object[] { 9.0, 18.0 },
new Object[] { 10.0, 20.0 }, new Object[] { -0.08, -0.16 },
new Object[] { 0.13, 0.26 }, new Object[] { -1.17, -2.35 },
new Object[] { 1.72, 3.45 }, new Object[] { 1.70, 3.41 },
new Object[] { 1.20, 2.41 }, new Object[] { -0.59, -1.18 },
new Object[] { 0.28, 0.57 }, new Object[] { 1.65, 3.30 },
new Object[] { -0.55, -1.08 } };
return env.fromElements(
new Params(0.0,0.0)
);
/**
 * Builds the default parameter data set from the rows of {@link #PARAMS}.
 *
 * @param env the execution environment used to create the data set
 * @return a data set with one {@code Params} element per row of {@code PARAMS}
 */
public static DataSet<Params> getDefaultParamsDataSet(
		ExecutionEnvironment env) {
	final List<Params> defaults = new LinkedList<Params>();
	for (int i = 0; i < PARAMS.length; i++) {
		// Each row holds the two initial values as boxed doubles.
		final Object[] row = PARAMS[i];
		defaults.add(new Params((Double) row[0], (Double) row[1]));
	}
	return env.fromCollection(defaults);
}
public static DataSet<Data> getDefaultDataDataSet(ExecutionEnvironment env){
public static DataSet<Data> getDefaultDataDataSet(ExecutionEnvironment env) {
return env.fromElements(
new Data(0.5,1.0),
new Data(1.0,2.0),
new Data(2.0,4.0),
new Data(3.0,6.0),
new Data(4.0,8.0),
new Data(5.0,10.0),
new Data(6.0,12.0),
new Data(7.0,14.0),
new Data(8.0,16.0),
new Data(9.0,18.0),
new Data(10.0,20.0),
new Data(-0.08,-0.16),
new Data(0.13,0.26),
new Data(-1.17,-2.35),
new Data(1.72,3.45),
new Data(1.70,3.41),
new Data(1.20,2.41),
new Data(-0.59,-1.18),
new Data(0.28,0.57),
new Data(1.65,3.30),
new Data(-0.55,-1.08)
);
List<Data> dataList = new LinkedList<Data>();
for (Object[] data : DATA) {
dataList.add(new Data((Double) data[0], (Double) data[1]));
}
return env.fromCollection(dataList);
}
}
\ No newline at end of file
///**
// * Licensed to the Apache Software Foundation (ASF) under one
// * or more contributor license agreements. See the NOTICE file
// * distributed with this work for additional information
// * regarding copyright ownership. The ASF licenses this file
// * to you under the Apache License, Version 2.0 (the
// * "License"); you may not use this file except in compliance
// * with the License. You may obtain a copy of the License at
// *
// * http://www.apache.org/licenses/LICENSE-2.0
// *
// * Unless required by applicable law or agreed to in writing, software
// * distributed under the License is distributed on an "AS IS" BASIS,
// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// * See the License for the specific language governing permissions and
// * limitations under the License.
// */
//
//
//package org.apache.flink.examples.scala.datamining
//
//import scala.math._
//import org.apache.flink.api.scala._
//import org.apache.flink.api.scala.operators._
//
//
//abstract class BatchGradientDescent(eps: Double, eta: Double, lambda: Double, examplesInput: String, weightsInput: String, weightsOutput: String) extends Serializable {
// def computeGradient(example: Array[Double], weight: Array[Double]): (Double, Array[Double])
//
// def updateWeight = (prev: (Int, Array[Double], Double), vg: ValueAndGradient) => {
// val (id, wOld, eta) = prev
// val ValueAndGradient(_, lossSum, gradSum) = vg
//
// val delta = lossSum + lambda * wOld.norm
// val wNew = (wOld + (gradSum * eta)) * (1 - eta * lambda)
// (id, delta, wNew, eta * 0.9)
// }
//
// class WeightVector(vector: Array[Double]) {
// def +(that: Array[Double]): Array[Double] = (vector zip that) map { case (x1, x2) => x1 + x2 }
// def -(that: Array[Double]): Array[Double] = (vector zip that) map { case (x1, x2) => x1 - x2 }
// def *(x: Double): Array[Double] = vector map { x * _ }
// def norm: Double = sqrt(vector map { x => x * x } reduce { _ + _ })
// }
//
// implicit def array2WeightVector(vector: Array[Double]): WeightVector = new WeightVector(vector)
//
// case class ValueAndGradient(id: Int, value: Double, gradient: Array[Double]) {
// def this(id: Int, vg: (Double, Array[Double])) = this(id, vg._1, vg._2)
// def +(that: ValueAndGradient) = ValueAndGradient(id, value + that.value, gradient + that.gradient)
// }
//
// def readVector = (line: String) => {
// val Seq(id, vector @ _*) = line.split(',').toSeq
// id.toInt -> (vector map { _.toDouble } toArray)
// }
//
// def formatOutput = (id: Int, vector: Array[Double]) => "%s,%s".format(id, vector.mkString(","))
//
// def getPlan() = {
//
// val examples = DataSource(examplesInput, DelimitedInputFormat(readVector))
// val weights = DataSource(weightsInput, DelimitedInputFormat(readVector))
//
// def gradientDescent = (s: DataSetOLD[(Int, Array[Double])], ws: DataSetOLD[(Int, Array[Double], Double)]) => {
//
// val lossesAndGradients = ws cross examples map { (w, ex) => new ValueAndGradient(w._1, computeGradient(ex._2, w._2)) }
// val lossAndGradientSums = lossesAndGradients groupBy { _.id } reduce (_ + _)
// val newWeights = ws join lossAndGradientSums where { _._1 } isEqualTo { _.id } map updateWeight
//
// val s1 = newWeights map { case (wId, _, wNew, _) => (wId, wNew) } // updated solution elements
// val ws1 = newWeights filter { case (_, delta, _, _) => delta > eps } map { case (wId, _, wNew, etaNew) => (wId, wNew, etaNew) } // new workset
//
// (s1, ws1)
// }
//
// val newWeights = weights.iterateWithDelta(weights.map { case (id, w) => (id, w, eta) }, {_._1}, gradientDescent, 10)
//
// val output = newWeights.write(weightsOutput, DelimitedOutputFormat(formatOutput.tupled))
// new ScalaPlan(Seq(output), "Batch Gradient Descent")
// }
//}
//
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.examples.scala.ml
import java.io.Serializable
import org.apache.flink.api.common.functions._
import org.apache.flink.api.scala._
import org.apache.flink.configuration.Configuration
import org.apache.flink.example.java.clustering.util.KMeansData
import org.apache.flink.example.java.ml.util.LinearRegressionData
import scala.collection.JavaConverters._
/**
* This example implements a basic Linear Regression to solve the y = theta0 + theta1*x problem using the batch gradient descent algorithm.
*
* <p>
* Linear Regression with the BGD (batch gradient descent) algorithm is an iterative optimization algorithm and works as follows:<br>
* Given a data set and a target set, BGD tries to find the parameters that best fit the data set to the target set.
* In each iteration, the algorithm computes the gradient of the cost function and uses it to update all the parameters.
* The algorithm terminates after a fixed number of iterations (as in this implementation).
* With enough iterations, the algorithm can minimize the cost function and find the best parameters.
* This is the Wikipedia entry for the <a href = "http://en.wikipedia.org/wiki/Linear_regression">Linear regression</a> and <a href = "http://en.wikipedia.org/wiki/Gradient_descent">Gradient descent algorithm</a>.
*
* <p>
* This implementation works on one-dimensional data and finds the two-dimensional theta.<br>
* It finds the best theta parameters to fit the target.
*
* <p>
* Input files are plain text files and must be formatted as follows:
* <ul>
* <li>Data points are represented as two double values separated by a blank character. The first one represents X (the training data) and the second represents Y (the target).
* Data points are separated by newline characters.<br>
* For example <code>"-0.02 -0.04\n5.3 10.6\n"</code> gives two data points (x=-0.02, y=-0.04) and (x=5.3, y=10.6).
* </ul>
*
* <p>
* This example shows how to use:
* <ul>
* <li> Bulk iterations
* <li> Broadcast variables in bulk iterations
* <li> Custom Java objects (PoJos)
* </ul>
*/
object LinearRegression {
// *************************************************************************
// PROGRAM
// *************************************************************************
/**
 * Program entry point: builds the bulk iteration that refines the parameters
 * and either writes the result as text or prints it.
 *
 * Nothing is executed when `parseParameters` rejects the arguments.
 *
 * @param args the raw command line arguments, handed to `parseParameters`
 */
def main(args: Array[String]): Unit = {
  if (parseParameters(args)) {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val samples: DataSet[Data] = getDataSet(env)
    val initialParams: DataSet[Params] = getParamsDataSet(env)

    val fitted = initialParams.iterate(numIterations) { current =>
      // Every sample proposes an updated parameter pair together with a count.
      val proposals = samples
        .map(new SubUpdate).withBroadcastSet(current, "parameters")

      // Add up all proposals and their counts ...
      val summed = proposals.reduce { (left, right) =>
        val theta0Sum: Double = left._1.getTheta0 + right._1.getTheta0
        val theta1Sum: Double = left._1.getTheta1 + right._1.getTheta1
        val combined: Params = new Params(theta0Sum, theta1Sum)
        (combined, left._2 + right._2)
      }

      // ... and divide by the count to obtain the averaged parameters.
      summed.map { total => total._1.div(total._2) }
    }

    if (fileOutput) {
      fitted.writeAsText(outputPath)
    } else {
      fitted.print()
    }

    env.execute("Scala Linear Regression example")
  }
}
// *************************************************************************
// DATA TYPES
// *************************************************************************
/**
 * One training sample: `x` is the input and `y` is the target.
 *
 * Both fields are mutable and start at 0.0 when the no-argument constructor
 * is used.
 */
class Data extends Serializable {
  /** The input value. */
  var x: Double = 0.0

  /** The target value. */
  var y: Double = 0.0

  /** Creates a sample with the given input and target. */
  def this(x: Double, y: Double) = {
    this()
    this.x = x
    this.y = y
  }

  /** Renders the sample as `(x|y)`. */
  override def toString: String = s"($x|$y)"
}
/**
 * The pair of model parameters, theta0 and theta1.
 *
 * Both values are mutable and start at 0.0 when the no-argument constructor
 * is used.
 */
class Params extends Serializable {
  private var theta0: Double = 0.0
  private var theta1: Double = 0.0

  /** Creates a parameter pair with `x0` as theta0 and `x1` as theta1. */
  def this(x0: Double, x1: Double) = {
    this()
    this.theta0 = x0
    this.theta1 = x1
  }

  /** Renders the pair as the two values separated by a single blank. */
  override def toString: String = s"$theta0 $theta1"

  def getTheta0: Double = theta0

  def getTheta1: Double = theta1

  def setTheta0(theta0: Double): Unit = {
    this.theta0 = theta0
  }

  def setTheta1(theta1: Double): Unit = {
    this.theta1 = theta1
  }

  /**
   * Divides both parameters by `a` in place.
   *
   * @param a the divisor
   * @return this instance, after it has been modified
   */
  def div(a: Integer): Params = {
    theta0 = theta0 / a
    theta1 = theta1 / a
    this
  }
}
// *************************************************************************
// USER FUNCTIONS
// *************************************************************************
/**
 * Computes a single BGD-style update of both parameters for one data sample.
 *
 * The current parameters are read from the broadcast variable named
 * "parameters". Each input sample is mapped to the updated parameter pair
 * together with a count of 1.
 */
class SubUpdate extends RichMapFunction[Data, Tuple2[Params, Integer]] {
  // Assigned in open(); stays null until then.
  private var parameters: Traversable[Params] = null
  var parameter: Params = null
  private val count: Int = 1

  /**
   * Reads the parameters from the broadcast variable into a collection and
   * selects the parameter pair used by `map`.
   *
   * The selection is done here, once per `open` call, rather than once per
   * record in `map`.
   */
  override def open(parameters: Configuration): Unit = {
    this.parameters = getRuntimeContext.getBroadcastVariable[Params]("parameters").asScala
    // Use the last broadcast element; an empty collection leaves `parameter` untouched.
    this.parameters.lastOption.foreach { p => this.parameter = p }
  }

  /**
   * Applies one update step with a fixed step size of 0.01.
   *
   * @param in the data sample
   * @return the updated parameters paired with the count 1
   */
  def map(in: Data): Tuple2[Params, Integer] = {
    // Difference between the value predicted for in.x and the target in.y.
    val residual: Double = (parameter.getTheta0 + (parameter.getTheta1 * in.x)) - in.y
    val theta_0: Double = parameter.getTheta0 - 0.01 * residual
    val theta_1: Double = parameter.getTheta1 - 0.01 * (residual * in.x)
    new Tuple2[Params, Integer](new Params(theta_0, theta_1), count)
  }
}
// *************************************************************************
// UTIL METHODS
// *************************************************************************
// Program settings; parseParameters overwrites them when command line arguments are given.

// True once arguments were supplied: input is then read from dataPath and the
// result is written to outputPath instead of being printed.
private var fileOutput: Boolean = false
// Path of the input data file (first argument); null while the built-in data is used.
private var dataPath: String = null
// Path the result is written to (second argument); null while the result is printed.
private var outputPath: String = null
// Number of iterations (third argument); 10 unless overridden.
private var numIterations: Int = 10
/**
 * Interprets the command line arguments and stores them in the program settings.
 *
 * With no arguments the built-in defaults stay in effect and a hint is printed.
 * With exactly three arguments the data path, result path and iteration count
 * are taken from them. Any other argument count prints the usage to stderr.
 *
 * @param programArguments the raw command line arguments
 * @return true if the program should run, false if the arguments were rejected
 */
private def parseParameters(programArguments: Array[String]): Boolean = {
  if (programArguments.isEmpty) {
    System.out.println("Executing Linear Regression example with default parameters and built-in default data.")
    System.out.println(" Provide parameters to read input data from files.")
    System.out.println(" See the documentation for the correct format of input files.")
    System.out.println(" We provide a data generator to create synthetic input files for this program.")
    System.out.println(" Usage: LinearRegression <data path> <result path> <num iterations>")
    true
  } else if (programArguments.length == 3) {
    // Switch to file mode only once the argument count has been validated.
    fileOutput = true
    dataPath = programArguments(0)
    outputPath = programArguments(1)
    numIterations = Integer.parseInt(programArguments(2))
    true
  } else {
    System.err.println("Usage: LinearRegression <data path> <result path> <num iterations>")
    // This branch must be the method's result; otherwise the caller would go on
    // to run with unset paths.
    false
  }
}
/**
 * Provides the training samples.
 *
 * Reads blank-separated pairs of doubles from `dataPath` when `fileOutput` is
 * set; otherwise converts the rows of `LinearRegressionData.DATA`.
 *
 * @param env the execution environment used to create the data set
 * @return the samples as a data set of `Data`
 */
private def getDataSet(env: ExecutionEnvironment): DataSet[Data] = {
  if (!fileOutput) {
    val defaults = LinearRegressionData.DATA map {
      case Array(input, target) =>
        new Data(input.asInstanceOf[Double], target.asInstanceOf[Double])
    }
    env.fromCollection(defaults)
  } else {
    val rows = env.readCsvFile[(Double, Double)](
      dataPath,
      fieldDelimiter = ' ',
      includedFields = Array(0, 1))
    rows.map { row => new Data(row._1, row._2) }
  }
}
/**
 * Provides the initial parameters by converting the rows of
 * `LinearRegressionData.PARAMS`.
 *
 * @param env the execution environment used to create the data set
 * @return the initial parameters as a data set of `Params`
 */
private def getParamsDataSet(env: ExecutionEnvironment): DataSet[Params] = {
  val initial = LinearRegressionData.PARAMS map {
    case Array(first, second) =>
      new Params(first.asInstanceOf[Double], second.asInstanceOf[Double])
  }
  env.fromCollection(initial)
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册