提交 9219af7b 编写于 作者: T Till Rohrmann

[FLINK-1718] [ml] Adds sparse matrix and sparse vector types

上级 c6358024
......@@ -41,9 +41,9 @@
</dependency>
<dependency>
<groupId>com.github.fommil.netlib</groupId>
<artifactId>core</artifactId>
<version>1.1.2</version>
<groupId>org.scalanlp</groupId>
<artifactId>breeze_2.10</artifactId>
<version>0.11.1</version>
</dependency>
<dependency>
......
......@@ -24,13 +24,15 @@ package org.apache.flink.ml.math
*
* @param numRows Number of rows
* @param numCols Number of columns
* @param values Array of matrix elements in column major order
* @param data Array of matrix elements in column major order
*/
case class DenseMatrix(val numRows: Int,
val numCols: Int,
val values: Array[Double]) extends Matrix {
val data: Array[Double]) extends Matrix {
require(numRows * numCols == values.length, s"The number of values ${values.length} does " +
import DenseMatrix._
require(numRows * numCols == data.length, s"The number of values ${data.length} does " +
s"not correspond to its dimensions ($numRows, $numCols).")
/**
......@@ -41,32 +43,129 @@ case class DenseMatrix(val numRows: Int,
* @return matrix entry at (row, col)
*/
override def apply(row: Int, col: Int): Double = {
require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).")
val index = locate(row, col)
val index = col * numRows + row
values(index)
data(index)
}
override def toString: String = {
s"DenseMatrix($numRows, $numCols, ${values.mkString(", ")})"
val result = StringBuilder.newBuilder
result.append(s"DenseMatrix($numRows, $numCols)\n")
val linewidth = LINE_WIDTH
val columnsFieldWidths = for(row <- 0 until math.min(numRows, MAX_ROWS)) yield {
var column = 0
var maxFieldWidth = 0
while(column * maxFieldWidth < linewidth && column < numCols) {
val fieldWidth = printEntry(row, column).length + 2
if(fieldWidth > maxFieldWidth) {
maxFieldWidth = fieldWidth
}
if(column * maxFieldWidth < linewidth) {
column += 1
}
}
(column, maxFieldWidth)
}
val (columns, fieldWidths) = columnsFieldWidths.unzip
val maxColumns = columns.min
val fieldWidth = fieldWidths.max
for(row <- 0 until math.min(numRows, MAX_ROWS)) {
for(col <- 0 until maxColumns) {
val str = printEntry(row, col)
result.append(" " * (fieldWidth - str.length) + str)
}
if(maxColumns < numCols) {
result.append("...")
}
result.append("\n")
}
if(numRows > MAX_ROWS) {
result.append("...\n")
}
result.toString()
}
override def equals(obj: Any): Boolean = {
obj match {
case dense: DenseMatrix =>
numRows == dense.numRows && numCols == dense.numCols && values.zip(dense.values).forall {
case (a, b) => a == b
}
case _ => false
numRows == dense.numRows && numCols == dense.numCols && data.sameElements(dense.data)
case _ => super.equals(obj)
}
}
/** Element wise update function
*
* @param row row index
* @param col column index
* @param value value to set at (row, col)
*/
override def update(row: Int, col: Int, value: Double): Unit = {
val index = locate(row, col)
data(index) = value
}
def toSparseMatrix: SparseMatrix = {
val entries = for(row <- 0 until numRows; col <- 0 until numCols) yield {
(row, col, apply(row, col))
}
SparseMatrix.fromCOO(numRows, numCols, entries.filter(_._3 != 0))
}
/** Calculates the linear index of the respective matrix entry
*
* @param row
* @param col
* @return
*/
private def locate(row: Int, col: Int): Int = {
require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).")
row + col * numRows
}
/** Converts the entry at (row, col) to string
*
* @param row
* @param col
* @return
*/
private def printEntry(row: Int, col: Int): String = {
val index = locate(row, col)
data(index).toString
}
/** Copies the matrix instance
*
* @return Copy of itself
*/
override def copy: DenseMatrix = {
new DenseMatrix(numRows, numCols, data.clone)
}
}
object DenseMatrix {
val LINE_WIDTH = 100
val MAX_ROWS = 50
def apply(numRows: Int, numCols: Int, values: Array[Int]): DenseMatrix = {
new DenseMatrix(numRows, numCols, values.map(_.toDouble))
}
......
......@@ -22,16 +22,16 @@ package org.apache.flink.ml.math
* Dense vector implementation of [[Vector]]. The data is represented in a continuous array of
* doubles.
*
* @param values Array of doubles to store the vector elements
* @param data Array of doubles to store the vector elements
*/
case class DenseVector(val values: Array[Double]) extends Vector {
case class DenseVector(val data: Array[Double]) extends Vector {
/**
* Number of elements in a vector
* @return
*/
override def size: Int = {
values.length
data.length
}
/**
......@@ -41,23 +41,19 @@ case class DenseVector(val values: Array[Double]) extends Vector {
* @return element at the given index
*/
override def apply(index: Int): Double = {
require(0 <= index && index < values.length, s"Index $index is out of bounds " +
s"[0, ${values.length})")
values(index)
require(0 <= index && index < data.length, s"Index $index is out of bounds " +
s"[0, ${data.length})")
data(index)
}
override def toString: String = {
s"DenseVector(${values.mkString(", ")})"
s"DenseVector(${data.mkString(", ")})"
}
override def equals(obj: Any): Boolean = {
obj match {
case dense: DenseVector =>
values.length == dense.values.length && values.zip(dense.values).forall{
case (a,b) => a == b
}
case _ => false
case dense: DenseVector => data.length == dense.data.length && data.sameElements(dense.data)
case _ => super.equals(obj)
}
}
......@@ -67,7 +63,25 @@ case class DenseVector(val values: Array[Double]) extends Vector {
* @return Copy of the vector instance
*/
override def copy: Vector = {
DenseVector(values.clone())
DenseVector(data.clone())
}
/** Updates the element at the given index with the provided value
*
* @param index
* @param value
*/
override def update(index: Int, value: Double): Unit = {
require(0 <= index && index < data.length, s"Index $index is out of bounds " +
s"[0, ${data.length})")
data(index) = value
}
def toSparseVector: SparseVector = {
val nonZero = (0 until size).zip(data).filter(_._2 != 0)
SparseVector.fromCOO(size, nonZero)
}
}
......
......@@ -18,28 +18,52 @@
package org.apache.flink.ml.math
/**
* Base trait for a matrix representation
*/
/** Base trait for a matrix representation
*
*/
trait Matrix {
/**
* Number of rows
* @return
*/
/** Number of rows
*
* @return
*/
def numRows: Int
/**
* Number of columns
* @return
*/
/** Number of columns
*
* @return
*/
def numCols: Int
/**
* Element wise access function
* @param row row index
* @param col column index
* @return matrix entry at (row, col)
*/
/** Element wise access function
*
* @param row row index
* @param col column index
* @return matrix entry at (row, col)
*/
def apply(row: Int, col: Int): Double
/** Element wise update function
*
* @param row row index
* @param col column index
* @param value value to set at (row, col)
*/
def update(row: Int, col: Int, value: Double): Unit
/** Copies the matrix instance
*
* @return Copy of itself
*/
def copy: Matrix
override def equals(obj: Any): Boolean = {
obj match {
case matrix: Matrix if numRows == matrix.numRows && numCols == matrix.numCols =>
val coordinates = for(row <- 0 until numRows; col <- 0 until numCols) yield (row, col)
coordinates forall { case(row, col) => this.apply(row, col) == matrix(row, col)}
case _ => false
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.math
import scala.util.Sorting
/** Sparse matrix using the compressed sparse column (CSC) representation.
*
* More details concerning the compressed sparse column (CSC) representation can be found
* [http://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_column_.28CSC_or_CCS.29].
*
* @param numRows Number of rows
* @param numCols Number of columns
* @param rowIndices Array containing the row indices of non-zero entries
* @param colPtrs Array containing the starting offsets in data for each column
* @param data Array containing the non-zero entries in column-major order
*/
class SparseMatrix(
val numRows: Int,
val numCols: Int,
val rowIndices: Array[Int],
val colPtrs: Array[Int],
val data: Array[Double])
extends Matrix {
/** Element wise access function
*
* @param row row index
* @param col column index
* @return matrix entry at (row, col)
*/
override def apply(row: Int, col: Int): Double = {
val index = locate(row, col)
if(index < 0){
0
} else {
data(index)
}
}
def toDenseMatrix: DenseMatrix = {
val result = DenseMatrix.zeros(numRows, numCols)
for(row <- 0 until numRows; col <- 0 until numCols) {
result(row, col) = apply(row, col)
}
result
}
/** Element wise update function
*
* @param row row index
* @param col column index
* @param value value to set at (row, col)
*/
override def update(row: Int, col: Int, value: Double): Unit = {
val index = locate(row, col)
if(index < 0) {
throw new IllegalArgumentException("Cannot update zero value of sparse matrix at index " +
s"($row, $col)")
} else {
data(index) = value
}
}
override def toString: String = {
val result = StringBuilder.newBuilder
result.append(s"SparseMatrix($numRows, $numCols)\n")
var columnIndex = 0
val fieldWidth = math.max(numRows, numCols).toString.length
val valueFieldWidth = data.map(_.toString.length).max + 2
for(index <- 0 until colPtrs.last) {
while(colPtrs(columnIndex + 1) <= index){
columnIndex += 1
}
val rowStr = rowIndices(index).toString
val columnStr = columnIndex.toString
val valueStr = data(index).toString
result.append("(" + " " * (fieldWidth - rowStr.length) + rowStr + "," +
" " * (fieldWidth - columnStr.length) + columnStr + ")")
result.append(" " * (valueFieldWidth - valueStr.length) + valueStr)
result.append("\n")
}
result.toString
}
private def locate(row: Int, col: Int): Int = {
require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).")
val startIndex = colPtrs(col)
val endIndex = colPtrs(col+1)
java.util.Arrays.binarySearch(rowIndices, startIndex, endIndex, row)
}
/** Copies the matrix instance
*
* @return Copy of itself
*/
override def copy: SparseMatrix = {
new SparseMatrix(numRows, numCols, rowIndices.clone, colPtrs.clone(), data.clone)
}
}
object SparseMatrix{
/** Constructs a sparse matrix from a coordinate list (COO) representation where each entry
* is stored as a tuple of (rowIndex, columnIndex, value).
* @param numRows
* @param numCols
* @param entries
* @return
*/
def fromCOO(numRows: Int, numCols: Int, entries: (Int, Int, Double)*): SparseMatrix = {
fromCOO(numRows, numCols, entries)
}
/** Constructs a sparse matrix from a coordinate list (COO) representation where each entry
* is stored as a tuple of (rowIndex, columnIndex, value).
*
* @param numRows
* @param numCols
* @param entries
* @return
*/
def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = {
val entryArray = entries.toArray
entryArray.foreach{ case (row, col, _) =>
require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
require(0 <= col && col < numCols, s"Columm $col is out of bounds [0, $numCols).")
}
val COOOrdering = new Ordering[(Int, Int, Double)] {
override def compare(x: (Int, Int, Double), y: (Int, Int, Double)): Int = {
if(x._2 < y._2) {
-1
} else if(x._2 > y._2) {
1
} else {
x._1 - y._1
}
}
}
Sorting.quickSort(entryArray)(COOOrdering)
val nnz = entryArray.length
val data = new Array[Double](nnz)
val rowIndices = new Array[Int](nnz)
val colPtrs = new Array[Int](numCols + 1)
var (lastRow, lastCol, lastValue) = entryArray(0)
rowIndices(0) = lastRow
data(0) = lastValue
var i = 1
var lastDataIndex = 0
while(i < nnz) {
val (curRow, curCol, curValue) = entryArray(i)
if(lastRow == curRow && lastCol == curCol) {
// add values with identical coordinates
data(lastDataIndex) += curValue
} else {
lastDataIndex += 1
data(lastDataIndex) = curValue
rowIndices(lastDataIndex) = curRow
lastRow = curRow
}
while(lastCol < curCol) {
lastCol += 1
colPtrs(lastCol) = lastDataIndex
}
i += 1
}
lastDataIndex += 1
while(lastCol < numCols) {
colPtrs(lastCol + 1) = lastDataIndex
lastCol += 1
}
val prunedRowIndices = if(lastDataIndex < nnz) {
val prunedArray = new Array[Int](lastDataIndex)
rowIndices.copyToArray(prunedArray)
prunedArray
} else {
rowIndices
}
val prunedData = if(lastDataIndex < nnz) {
val prunedArray = new Array[Double](lastDataIndex)
data.copyToArray(prunedArray)
prunedArray
} else {
data
}
new SparseMatrix(numRows, numCols, prunedRowIndices, colPtrs, prunedData)
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.math
import scala.util.Sorting
/** Sparse vector implementation storing the data in two arrays. One index contains the sorted
* indices of the non-zero vector entries and the other the corresponding vector entries
*/
class SparseVector(
val size: Int,
val indices: Array[Int],
val data: Array[Double])
extends Vector {
/** Updates the element at the given index with the provided value
*
* @param index
* @param value
*/
override def update(index: Int, value: Double): Unit = {
val resolvedIndex = locate(index)
if (resolvedIndex < 0) {
throw new IllegalArgumentException("Cannot update zero value of sparse vector at index " +
index)
} else {
data(resolvedIndex) = value
}
}
/** Copies the vector instance
*
* @return Copy of the vector instance
*/
override def copy: Vector = {
new SparseVector(size, indices.clone, data.clone)
}
/** Element wise access function
*
* * @param index index of the accessed element
* @return element with index
*/
override def apply(index: Int): Double = {
val resolvedIndex = locate(index)
if(resolvedIndex < 0) {
0
} else {
data(resolvedIndex)
}
}
def toDenseVector: DenseVector = {
val denseVector = DenseVector.zeros(size)
for(index <- 0 until size) {
denseVector(index) = this(index)
}
denseVector
}
private def locate(index: Int): Int = {
require(0 <= index && index < size, s"Index $index is out of bounds [0, $size).")
java.util.Arrays.binarySearch(indices, 0, indices.length, index)
}
}
object SparseVector {
/** Constructs a sparse vector from a coordinate list (COO) representation where each entry
* is stored as a tuple of (index, value).
*
* @param size
* @param entries
* @return
*/
def fromCOO(size: Int, entries: (Int, Double)*): SparseVector = {
fromCOO(size, entries)
}
/** Constructs a sparse vector from a coordinate list (COO) representation where each entry
* is stored as a tuple of (index, value).
*
* @param size
* @param entries
* @return
*/
def fromCOO(size: Int, entries: Iterable[(Int, Double)]): SparseVector = {
val entryArray = entries.toArray
val COOOrdering = new Ordering[(Int, Double)] {
override def compare(x: (Int, Double), y: (Int, Double)): Int = {
x._1 - y._1
}
}
Sorting.quickSort(entryArray)(COOOrdering)
// calculate size of the array
val arraySize = entryArray.foldLeft((-1, 0)){ case ((lastIndex, numRows), (index, _)) =>
if(lastIndex == index) {
(lastIndex, numRows)
} else {
(index, numRows + 1)
}
}._2
val indices = new Array[Int](arraySize)
val data = new Array[Double](arraySize)
val (index, value) = entryArray(0)
indices(0) = index
data(0) = value
var i = 1
var lastIndex = indices(0)
var lastDataIndex = 0
while(i < entryArray.length) {
val (curIndex, curValue) = entryArray(i)
if(curIndex == lastIndex) {
data(lastDataIndex) += curValue
} else {
lastDataIndex += 1
data(lastDataIndex) = curValue
indices(lastDataIndex) = curIndex
lastIndex = curIndex
}
i += 1
}
new SparseVector(size, indices, data)
}
}
......@@ -18,29 +18,45 @@
package org.apache.flink.ml.math
/**
* Base trait for Vectors
*/
/** Base trait for Vectors
*
*/
trait Vector {
/**
* Number of elements in a vector
* @return
*/
/** Number of elements in a vector
*
* @return
*/
def size: Int
/**
* Element wise access function
*
* @param index index of the accessed element
* @return element with index
*/
/** Element wise access function
*
* * @param index index of the accessed element
* @return element with index
*/
def apply(index: Int): Double
/**
* Copies the vector instance
*
* @return Copy of the vector instance
*/
/** Updates the element at the given index with the provided value
*
* @param index
* @param value
*/
def update(index: Int, value: Double): Unit
/** Copies the vector instance
*
* @return Copy of the vector instance
*/
def copy: Vector
override def equals(obj: Any): Boolean = {
obj match {
case vector: Vector if size == vector.size =>
0 until size forall { idx =>
this(idx) == vector(idx)
}
case _ => false
}
}
}
......@@ -27,7 +27,7 @@ package object math {
override def iterator: Iterator[Double] = {
matrix match {
case dense: DenseMatrix => dense.values.iterator
case dense: DenseMatrix => dense.data.iterator
}
}
}
......@@ -35,14 +35,14 @@ package object math {
implicit class RichVector(vector: Vector) extends Iterable[Double] {
override def iterator: Iterator[Double] = {
vector match {
case dense: DenseVector => dense.values.iterator
case dense: DenseVector => dense.data.iterator
}
}
}
implicit def vector2Array(vector: Vector): Array[Double] = {
vector match {
case dense: DenseVector => dense.values
case dense: DenseVector => dense.data
}
}
}
......@@ -18,13 +18,13 @@
package org.apache.flink.ml.math
import org.scalatest.FlatSpec
import org.junit.Test
import org.scalatest.ShouldMatchers
class DenseMatrixSuite extends FlatSpec {
class DenseMatrixTest extends ShouldMatchers {
behavior of "A DenseMatrix"
it should "contain the initialization data after intialization" in {
@Test
def testDataAfterInitialization: Unit = {
val numRows = 10
val numCols = 13
......@@ -40,7 +40,8 @@ class DenseMatrixSuite extends FlatSpec {
}
}
it should "throw an IllegalArgumentException in case of an invalid index access" in {
@Test
def testIllegalArgumentExceptionInCaseOfInvalidIndexAccess: Unit = {
val numRows = 10
val numCols = 13
......@@ -66,4 +67,23 @@ class DenseMatrixSuite extends FlatSpec {
matrix(numRows, numCols)
}
}
@Test
def testCopy: Unit = {
val numRows = 4
val numCols = 5
val data = Array.range(0, numRows*numCols)
val denseMatrix = DenseMatrix.apply(numRows, numCols, data)
val copy = denseMatrix.copy
denseMatrix should equal(copy)
copy(0, 0) = 1
denseMatrix should not equal(copy)
}
}
......@@ -18,13 +18,14 @@
package org.apache.flink.ml.math
import org.scalatest.FlatSpec
import org.junit.Test
import org.scalatest.ShouldMatchers
class DenseVectorSuite extends FlatSpec {
behavior of "A DenseVector"
class DenseVectorTest extends ShouldMatchers {
it should "contain the initialization data after initialization" in {
@Test
def testDataAfterInitialization {
val data = Array.range(1,10)
val vector = DenseVector(data)
......@@ -34,7 +35,8 @@ class DenseVectorSuite extends FlatSpec {
data.zip(vector).foreach{case (expected, actual) => assertResult(expected)(actual)}
}
it should "throw an IllegalArgumentException in case of an illegal index access" in {
@Test
def testIllegalArgumentExceptionInCaseOfIllegalIndexAccess {
val size = 10
val vector = DenseVector.zeros(size)
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.math
import org.junit.Test
import org.scalatest.ShouldMatchers
class SparseMatrixTest extends ShouldMatchers {
@Test
def testSparseMatrixFromCOO: Unit = {
val sparseMatrix = SparseMatrix.fromCOO(5, 5, (0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
(3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1))
val expectedSparseMatrix = SparseMatrix.fromCOO(5, 5, (3, 4, 42), (2, 1, 17), (3, 3, 88),
(4, 2, 99), (1, 4, 91))
val expectedDenseMatrix = DenseMatrix.zeros(5, 5)
expectedDenseMatrix(3, 4) = 42
expectedDenseMatrix(2, 1) = 17
expectedDenseMatrix(3, 3) = 88
expectedDenseMatrix(4, 2) = 99
expectedDenseMatrix(1, 4) = 91
sparseMatrix should equal(expectedSparseMatrix)
sparseMatrix should equal(expectedDenseMatrix)
sparseMatrix.toDenseMatrix.data.sameElements(expectedDenseMatrix.data) should be(true)
sparseMatrix(0, 1) = 10
intercept[IllegalArgumentException]{
sparseMatrix(1, 1) = 1
}
}
@Test
def testInvalidIndexAccess: Unit = {
val sparseVector = SparseVector.fromCOO(5, (1, 1), (3, 3), (4, 4))
intercept[IllegalArgumentException] {
sparseVector(-1)
}
intercept[IllegalArgumentException] {
sparseVector(5)
}
sparseVector(0) should equal(0)
sparseVector(3) should equal(3)
}
@Test
def testSparseMatrixFromCOOWithInvalidIndices: Unit = {
intercept[IllegalArgumentException]{
val sparseMatrix = SparseMatrix.fromCOO(5 ,5, (5, 0, 10), (0, 0, 0), (0, 1, 0), (3, 4, 43),
(2, 1, 17))
}
intercept[IllegalArgumentException]{
val sparseMatrix = SparseMatrix.fromCOO(5, 5, (0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
(-1, 4, 20))
}
}
@Test
def testSparseMatrixCopy: Unit = {
val sparseMatrix = SparseMatrix.fromCOO(4, 4, (0, 1, 2), (2, 3, 1), (2, 0, 42), (1, 3, 3))
val copy = sparseMatrix.copy
sparseMatrix should equal(copy)
copy(2, 3) = 2
sparseMatrix should not equal(copy)
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.math
import org.junit.Test
import org.scalatest.ShouldMatchers
class SparseVectorTest extends ShouldMatchers{
@Test
def testDataAfterInitialization: Unit = {
val sparseVector = SparseVector.fromCOO(5, (0, 1), (2, 0), (4, 42), (0, 3))
val expectedSparseVector = SparseVector.fromCOO(5, (0, 4), (4, 42))
val expectedDenseVector = DenseVector.zeros(5)
expectedDenseVector(0) = 4
expectedDenseVector(4) = 42
sparseVector should equal(expectedSparseVector)
sparseVector should equal(expectedDenseVector)
val denseVector = sparseVector.toDenseVector
denseVector should equal(expectedDenseVector)
}
@Test
def testInvalidIndexAccess: Unit = {
val sparseVector = SparseVector.fromCOO(5, (0, 1), (4, 10), (3, 5))
intercept[IllegalArgumentException] {
sparseVector(-1)
}
intercept[IllegalArgumentException] {
sparseVector(5)
}
}
@Test
def testSparseVectorFromCOOWithInvalidIndices: Unit = {
intercept[IllegalArgumentException] {
val sparseVector = SparseVector.fromCOO(5, (0, 1), (-1, 34), (3, 2))
}
intercept[IllegalArgumentException] {
val sparseVector = SparseVector.fromCOO(5, (0, 1), (4,3), (5, 1))
}
}
@Test
def testSparseVectorCopy: Unit = {
val sparseVector = SparseVector.fromCOO(5, (0, 1), (4, 3), (3, 2))
val copy = sparseVector.copy
sparseVector should equal(copy)
copy(3) = 3
sparseVector should not equal(copy)
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册