/***********************************************************************************************************************
 *
 * Copyright (C) 2010-2014 by the Stratosphere project (http://stratosphere.eu)
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 *
 **********************************************************************************************************************/
package eu.stratosphere.streaming.examples.ml;

import java.net.InetSocketAddress;
import java.util.Random;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.math.stat.regression.OLSMultipleLinearRegression;
import org.apache.log4j.Level;

import eu.stratosphere.api.java.tuple.Tuple;
import eu.stratosphere.api.java.tuple.Tuple1;
import eu.stratosphere.api.java.tuple.Tuple2;
import eu.stratosphere.client.minicluster.NepheleMiniCluster;
import eu.stratosphere.client.program.Client;
import eu.stratosphere.configuration.Configuration;
import eu.stratosphere.nephele.jobgraph.JobGraph;
import eu.stratosphere.streaming.api.JobGraphBuilder;
import eu.stratosphere.streaming.api.invokable.UserSinkInvokable;
import eu.stratosphere.streaming.api.invokable.UserSourceInvokable;
import eu.stratosphere.streaming.api.invokable.UserTaskInvokable;
import eu.stratosphere.streaming.api.streamrecord.StreamRecord;
import eu.stratosphere.streaming.util.LogUtils;

38
public class IncrementalOLS {
39 40 41

	public static class NewDataSource extends UserSourceInvokable {

42 43 44
		StreamRecord record = new StreamRecord(2, 1);

		Random rnd = new Random();
45 46 47

		@Override
		public void invoke() throws Exception {
48 49
			record.initRecords();
			for (int j = 0; j < 100; j++) {
M
Márton Balassi 已提交
50
				// pull new record from data source
51 52 53 54 55 56 57
				record.setTuple(getNewData());
				emit(record);
			}

		}

		private Tuple getNewData() throws InterruptedException {
58 59 60

			return new Tuple2<Boolean, Double[]>(false, new Double[] { rnd.nextDouble() * 3,
					rnd.nextDouble() * 5 });
61 62 63
		}
	}

64
	public static class TrainingDataSource extends UserSourceInvokable {
65

66 67 68
		private final int BATCH_SIZE = 10;

		StreamRecord record = new StreamRecord(2, BATCH_SIZE);
69

70
		Random rnd = new Random();
71 72 73 74

		@Override
		public void invoke() throws Exception {

75
			record.initRecords();
76

77
			for (int j = 0; j < 1000; j++) {
78
				for (int i = 0; i < BATCH_SIZE; i++) {
79
					record.setTuple(i, getTrainingData());
80 81 82 83 84 85
				}
				emit(record);
			}

		}

86
		private Tuple getTrainingData() throws InterruptedException {
87 88 89

			return new Tuple2<Double, Double[]>(rnd.nextDouble() * 10, new Double[] {
					rnd.nextDouble() * 3, rnd.nextDouble() * 5 });
90

91 92 93
		}
	}

94 95 96 97 98 99 100 101 102
	public static class PartialModelBuilder extends UserTaskInvokable {

		@Override
		public void invoke(StreamRecord record) throws Exception {
			emit(buildPartialModel(record));
		}

		protected StreamRecord buildPartialModel(StreamRecord record) {

103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
			Integer numOfTuples = record.getNumOfTuples();
			Integer numOfFeatures = ((Double[]) record.getField(1)).length;

			double[][] x = new double[numOfTuples][numOfFeatures];
			double[] y = new double[numOfTuples];

			for (int i = 0; i < numOfTuples; i++) {

				Tuple t = record.getTuple(i);
				Double[] x_i = t.getField(1);
				y[i] = t.getField(0);
				for (int j = 0; j < numOfFeatures; j++) {
					x[i][j] = x_i[j];
				}
			}

			OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
			ols.newSampleData(y, x);

			return new StreamRecord(new Tuple2<Boolean, Double[]>(true,
					(Double[]) ArrayUtils.toObject(ols.estimateRegressionParameters())));
		}
125 126 127 128
	}

	public static class Predictor extends UserTaskInvokable {

129 130
		// StreamRecord batchModel = null;
		Double[] partialModel = new Double[] { 0.0, 0.0 };
131 132 133

		@Override
		public void invoke(StreamRecord record) throws Exception {
134
			if (isModel(record)) {
135 136
				partialModel = (Double[]) record.getField(1);
				// batchModel = getBatchModel();
137 138 139 140 141 142
			} else {
				emit(predict(record));
			}

		}

143 144 145
		// protected StreamRecord getBatchModel() {
		// return new StreamRecord(new Tuple1<Integer>(1));
		// }
146 147

		protected boolean isModel(StreamRecord record) {
148
			return record.getBoolean(0);
149 150 151
		}

		protected StreamRecord predict(StreamRecord record) {
152 153 154 155 156 157 158 159
			Double[] x = (Double[]) record.getField(1);

			Double prediction = 0.0;
			for (int i = 0; i < x.length; i++) {
				prediction = prediction + x[i] * partialModel[i];
			}

			return new StreamRecord(new Tuple1<Double>(prediction));
160 161 162 163 164 165 166 167
		}

	}

	public static class Sink extends UserSinkInvokable {

		@Override
		public void invoke(StreamRecord record) throws Exception {
168
			System.out.println(record.getTuple());
169 170 171 172
		}
	}

	private static JobGraph getJobGraph() throws Exception {
173
		JobGraphBuilder graphBuilder = new JobGraphBuilder("IncrementalOLS");
174

175 176 177 178 179 180 181 182 183 184 185
		graphBuilder.setSource("NewData", NewDataSource.class, 1, 1);
		graphBuilder.setSource("TrainingData", TrainingDataSource.class, 1, 1);
		graphBuilder.setTask("PartialModelBuilder", PartialModelBuilder.class, 1, 1);
		graphBuilder.setTask("Predictor", Predictor.class, 1, 1);
		graphBuilder.setSink("Sink", Sink.class, 1, 1);

		graphBuilder.shuffleConnect("TrainingData", "PartialModelBuilder");
		graphBuilder.shuffleConnect("NewData", "Predictor");
		graphBuilder.broadcastConnect("PartialModelBuilder", "Predictor");
		graphBuilder.shuffleConnect("Predictor", "Sink");

186 187 188 189 190 191
		return graphBuilder.getJobGraph();
	}

	public static void main(String[] args) {

		// set logging parameters for local run
192
		LogUtils.initializeDefaultConsoleLogger(Level.INFO, Level.INFO);
193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221

		try {

			// generate JobGraph
			JobGraph jG = getJobGraph();
			Configuration configuration = jG.getJobConfiguration();

			if (args.length == 0 || args[0].equals("local")) {
				System.out.println("Running in Local mode");
				// start local cluster and submit JobGraph
				NepheleMiniCluster exec = new NepheleMiniCluster();
				exec.start();

				Client client = new Client(new InetSocketAddress("localhost", 6498), configuration);

				client.run(jG, true);

				exec.stop();
			} else if (args[0].equals("cluster")) {
				System.out.println("Running in Cluster mode");
				// submit JobGraph to the running cluster
				Client client = new Client(new InetSocketAddress("dell150", 6123), configuration);
				client.run(jG, true);
			}

		} catch (Exception e) {
			System.out.println(e);
		}
	}
M
Márton Balassi 已提交
222
}