提交 615cf42b 作者: Chiwan Park

[FLINK-2984] [ml] Extend libSVM file format support

This closes #1504.
上级 be055b7a
......@@ -18,12 +18,13 @@
package org.apache.flink.ml
import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.common.functions.{RichFlatMapFunction, RichMapFunction}
import org.apache.flink.api.java.operators.DataSink
import org.apache.flink.api.scala._
import org.apache.flink.configuration.Configuration
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.ml.math.SparseVector
import org.apache.flink.util.Collector
/** Convenience functions for machine learning tasks
*
......@@ -53,17 +54,21 @@ object MLUtils {
* file
*/
def readLibSVM(env: ExecutionEnvironment, filePath: String): DataSet[LabeledVector] = {
val labelCOODS = env.readTextFile(filePath).flatMap {
line =>
// remove all comments which start with a '#'
val commentFreeLine = line.takeWhile(_ != '#').trim
if(commentFreeLine.nonEmpty) {
val splits = commentFreeLine.split(' ')
val label = splits.head.toDouble
val sparseFeatures = splits.tail
val coos = sparseFeatures.map {
str =>
val labelCOODS = env.readTextFile(filePath).flatMap(
new RichFlatMapFunction[String, (Double, Array[(Int, Double)])] {
val splitPattern = "\\s+".r
override def flatMap(
line: String,
out: Collector[(Double, Array[(Int, Double)])]
): Unit = {
val commentFreeLine = line.takeWhile(_ != '#').trim
if (commentFreeLine.nonEmpty) {
val splits = splitPattern.split(commentFreeLine)
val label = splits.head.toDouble
val sparseFeatures = splits.tail
val coos = sparseFeatures.flatMap { str =>
val pair = str.split(':')
require(pair.length == 2, "Each feature entry has to have the form <feature>:<value>")
......@@ -71,14 +76,13 @@ object MLUtils {
val index = pair(0).toInt - 1
val value = pair(1).toDouble
(index, value)
}
Some((index, value))
}
Some((label, coos))
} else {
None
out.collect((label, coos))
}
}
}
})
// Calculate maximum dimension of vectors
val dimensionDS = labelCOODS.map {
......
......@@ -40,9 +40,9 @@ class MLUtilsSuite extends FlatSpec with Matchers with FlinkTestBase {
val content =
"""
|1 2:10.0 4:4.5 8:4.2 # foo
|1 2:10.0 4:4.5 8:4.2 # foo
|-1 1:9.0 4:-4.5 7:2.4 # bar
|0.4 3:1.0 8:-5.6 10:1.0
|0.4 3:1.0 8:-5.6 10:1.0
|-42.1 2:2.0 4:-6.1 3:5.1 # svm
""".stripMargin
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册