提交 3a69c72e 编写于 作者: D Doris Xin 提交者: Xiangrui Meng

[SPARK-2679] [MLLib] Ser/De for Double

Added a set of serializer/deserializer for Double in _common.py and PythonMLLibAPI in MLLib.

Author: Doris Xin <doris.s.xin@gmail.com>

Closes #1581 from dorx/doubleSerDe and squashes the following commits:

86a85b3 [Doris Xin] Merge branch 'master' into doubleSerDe
2bfe7a4 [Doris Xin] Removed magic byte
ad4d0d9 [Doris Xin] removed a space in unit
a9020bc [Doris Xin] units passed
7dad9af [Doris Xin] WIP
上级 aaf2b735
...@@ -54,6 +54,13 @@ class PythonMLLibAPI extends Serializable { ...@@ -54,6 +54,13 @@ class PythonMLLibAPI extends Serializable {
} }
} }
private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = {
require(bytes.length - offset == 8, "Wrong size byte array for Double")
val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
bb.order(ByteOrder.nativeOrder())
bb.getDouble
}
private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = { private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
val packetLength = bytes.length - offset val packetLength = bytes.length - offset
require(packetLength >= 5, "Byte array too short") require(packetLength >= 5, "Byte array too short")
...@@ -89,6 +96,22 @@ class PythonMLLibAPI extends Serializable { ...@@ -89,6 +96,22 @@ class PythonMLLibAPI extends Serializable {
Vectors.sparse(size, indices, values) Vectors.sparse(size, indices, values)
} }
/**
* Returns an 8-byte array for the input Double.
*
* Note: we currently do not use a magic byte for double for storage efficiency.
* This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety.
* The corresponding deserializer, deserializeDouble, needs to be modified as well if the
* serialization scheme changes.
*/
private[python] def serializeDouble(double: Double): Array[Byte] = {
val bytes = new Array[Byte](8)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.putDouble(double)
bytes
}
private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = { private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = {
val len = doubles.length val len = doubles.length
val bytes = new Array[Byte](5 + 8 * len) val bytes = new Array[Byte](5 + 8 * len)
......
...@@ -57,4 +57,12 @@ class PythonMLLibAPISuite extends FunSuite { ...@@ -57,4 +57,12 @@ class PythonMLLibAPISuite extends FunSuite {
assert(q.features === p.features) assert(q.features === p.features)
} }
} }
test("double serialization") {
for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue)) {
val bytes = py.serializeDouble(x)
val deser = py.deserializeDouble(bytes)
assert(x === deser)
}
}
} }
...@@ -72,9 +72,9 @@ except: ...@@ -72,9 +72,9 @@ except:
# Python interpreter must agree on what endian the machine is. # Python interpreter must agree on what endian the machine is.
DENSE_VECTOR_MAGIC = 1 DENSE_VECTOR_MAGIC = 1
SPARSE_VECTOR_MAGIC = 2 SPARSE_VECTOR_MAGIC = 2
DENSE_MATRIX_MAGIC = 3 DENSE_MATRIX_MAGIC = 3
LABELED_POINT_MAGIC = 4 LABELED_POINT_MAGIC = 4
...@@ -97,8 +97,28 @@ def _deserialize_numpy_array(shape, ba, offset, dtype=float64): ...@@ -97,8 +97,28 @@ def _deserialize_numpy_array(shape, ba, offset, dtype=float64):
return ar.copy() return ar.copy()
def _serialize_double(d):
"""
Serialize a double (float or numpy.float64) into a mutually understood format.
"""
if type(d) == float or type(d) == float64:
d = float64(d)
ba = bytearray(8)
_copyto(d, buffer=ba, offset=0, shape=[1], dtype=float64)
return ba
else:
raise TypeError("_serialize_double called on non-float input")
def _serialize_double_vector(v): def _serialize_double_vector(v):
"""Serialize a double vector into a mutually understood format. """
Serialize a double vector into a mutually understood format.
Note: we currently do not use a magic byte for double for storage
efficiency. This should be reconsidered when we add Ser/De for other
8-byte types (e.g. Long), for safety. The corresponding deserializer,
_deserialize_double, needs to be modified as well if the serialization
scheme changes.
>>> x = array([1,2,3]) >>> x = array([1,2,3])
>>> y = _deserialize_double_vector(_serialize_double_vector(x)) >>> y = _deserialize_double_vector(_serialize_double_vector(x))
...@@ -148,6 +168,28 @@ def _serialize_sparse_vector(v): ...@@ -148,6 +168,28 @@ def _serialize_sparse_vector(v):
return ba return ba
def _deserialize_double(ba, offset=0):
"""Deserialize a double from a mutually understood format.
>>> import sys
>>> _deserialize_double(_serialize_double(123.0)) == 123.0
True
>>> _deserialize_double(_serialize_double(float64(0.0))) == 0.0
True
>>> x = sys.float_info.max
>>> _deserialize_double(_serialize_double(sys.float_info.max)) == x
True
>>> y = float64(sys.float_info.max)
>>> _deserialize_double(_serialize_double(sys.float_info.max)) == y
True
"""
if type(ba) != bytearray:
raise TypeError("_deserialize_double called on a %s; wanted bytearray" % type(ba))
if len(ba) - offset != 8:
raise TypeError("_deserialize_double called on a %d-byte array; wanted 8 bytes." % nb)
return struct.unpack("d", ba[offset:])[0]
def _deserialize_double_vector(ba, offset=0): def _deserialize_double_vector(ba, offset=0):
"""Deserialize a double vector from a mutually understood format. """Deserialize a double vector from a mutually understood format.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册