提交 3a69c72e 编写于 作者: D Doris Xin 提交者: Xiangrui Meng

[SPARK-2679] [MLLib] Ser/De for Double

Added a set of serializer/deserializer for Double in _common.py and PythonMLLibAPI in MLLib.

Author: Doris Xin <doris.s.xin@gmail.com>

Closes #1581 from dorx/doubleSerDe and squashes the following commits:

86a85b3 [Doris Xin] Merge branch 'master' into doubleSerDe
2bfe7a4 [Doris Xin] Removed magic byte
ad4d0d9 [Doris Xin] removed a space in unit
a9020bc [Doris Xin] units passed
7dad9af [Doris Xin] WIP
上级 aaf2b735
......@@ -54,6 +54,13 @@ class PythonMLLibAPI extends Serializable {
}
}
private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = {
require(bytes.length - offset == 8, "Wrong size byte array for Double")
val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
bb.order(ByteOrder.nativeOrder())
bb.getDouble
}
private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
val packetLength = bytes.length - offset
require(packetLength >= 5, "Byte array too short")
......@@ -89,6 +96,22 @@ class PythonMLLibAPI extends Serializable {
Vectors.sparse(size, indices, values)
}
/**
* Returns an 8-byte array for the input Double.
*
* Note: we currently do not use a magic byte for double for storage efficiency.
* This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety.
* The corresponding deserializer, deserializeDouble, needs to be modified as well if the
* serialization scheme changes.
*/
private[python] def serializeDouble(double: Double): Array[Byte] = {
val bytes = new Array[Byte](8)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.putDouble(double)
bytes
}
private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = {
val len = doubles.length
val bytes = new Array[Byte](5 + 8 * len)
......
......@@ -57,4 +57,12 @@ class PythonMLLibAPISuite extends FunSuite {
assert(q.features === p.features)
}
}
test("double serialization") {
for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue)) {
val bytes = py.serializeDouble(x)
val deser = py.deserializeDouble(bytes)
assert(x === deser)
}
}
}
......@@ -72,9 +72,9 @@ except:
# Python interpreter must agree on what endian the machine is.
DENSE_VECTOR_MAGIC = 1
DENSE_VECTOR_MAGIC = 1
SPARSE_VECTOR_MAGIC = 2
DENSE_MATRIX_MAGIC = 3
DENSE_MATRIX_MAGIC = 3
LABELED_POINT_MAGIC = 4
......@@ -97,8 +97,28 @@ def _deserialize_numpy_array(shape, ba, offset, dtype=float64):
return ar.copy()
def _serialize_double(d):
"""
Serialize a double (float or numpy.float64) into a mutually understood format.
"""
if type(d) == float or type(d) == float64:
d = float64(d)
ba = bytearray(8)
_copyto(d, buffer=ba, offset=0, shape=[1], dtype=float64)
return ba
else:
raise TypeError("_serialize_double called on non-float input")
def _serialize_double_vector(v):
"""Serialize a double vector into a mutually understood format.
"""
Serialize a double vector into a mutually understood format.
Note: we currently do not use a magic byte for double for storage
efficiency. This should be reconsidered when we add Ser/De for other
8-byte types (e.g. Long), for safety. The corresponding deserializer,
_deserialize_double, needs to be modified as well if the serialization
scheme changes.
>>> x = array([1,2,3])
>>> y = _deserialize_double_vector(_serialize_double_vector(x))
......@@ -148,6 +168,28 @@ def _serialize_sparse_vector(v):
return ba
def _deserialize_double(ba, offset=0):
"""Deserialize a double from a mutually understood format.
>>> import sys
>>> _deserialize_double(_serialize_double(123.0)) == 123.0
True
>>> _deserialize_double(_serialize_double(float64(0.0))) == 0.0
True
>>> x = sys.float_info.max
>>> _deserialize_double(_serialize_double(sys.float_info.max)) == x
True
>>> y = float64(sys.float_info.max)
>>> _deserialize_double(_serialize_double(sys.float_info.max)) == y
True
"""
if type(ba) != bytearray:
raise TypeError("_deserialize_double called on a %s; wanted bytearray" % type(ba))
if len(ba) - offset != 8:
raise TypeError("_deserialize_double called on a %d-byte array; wanted 8 bytes." % nb)
return struct.unpack("d", ba[offset:])[0]
def _deserialize_double_vector(ba, offset=0):
"""Deserialize a double vector from a mutually understood format.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册