提交 2d25e348 编写于 作者: A Ankur Dave 提交者: Reynold Xin

Replace RoutingTableMessage with pair

RoutingTableMessage was used to construct routing tables to enable
joining VertexRDDs with partitioned edges. It stored three elements: the
destination vertex ID, the source edge partition, and a byte specifying
the position in which the edge partition referenced the vertex to enable
join elimination.

However, this was incompatible with sort-based shuffle (SPARK-2045). It
was also slightly wasteful, because partition IDs are usually much
smaller than 2^32, though this was mitigated by a custom serializer that
used variable-length encoding.

This commit replaces RoutingTableMessage with a pair of (VertexId, Int)
where the Int encodes both the source partition ID (in the lower 30
bits) and the position (in the top 2 bits).

Author: Ankur Dave <ankurdave@gmail.com>

Closes #1553 from ankurdave/remove-RoutingTableMessage and squashes the following commits:

697e17b [Ankur Dave] Replace RoutingTableMessage with pair
上级 60f0ae3d
...@@ -35,7 +35,6 @@ class GraphKryoRegistrator extends KryoRegistrator { ...@@ -35,7 +35,6 @@ class GraphKryoRegistrator extends KryoRegistrator {
def registerClasses(kryo: Kryo) { def registerClasses(kryo: Kryo) {
kryo.register(classOf[Edge[Object]]) kryo.register(classOf[Edge[Object]])
kryo.register(classOf[RoutingTableMessage])
kryo.register(classOf[(VertexId, Object)]) kryo.register(classOf[(VertexId, Object)])
kryo.register(classOf[EdgePartition[Object, Object]]) kryo.register(classOf[EdgePartition[Object, Object]])
kryo.register(classOf[BitSet]) kryo.register(classOf[BitSet])
......
...@@ -27,26 +27,13 @@ import org.apache.spark.util.collection.{BitSet, PrimitiveVector} ...@@ -27,26 +27,13 @@ import org.apache.spark.util.collection.{BitSet, PrimitiveVector}
import org.apache.spark.graphx._ import org.apache.spark.graphx._
import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
/** import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
* A message from the edge partition `pid` to the vertex partition containing `vid` specifying that
* the edge partition references `vid` in the specified `position` (src, dst, or both).
*/
private[graphx]
class RoutingTableMessage(
var vid: VertexId,
var pid: PartitionID,
var position: Byte)
extends Product2[VertexId, (PartitionID, Byte)] with Serializable {
override def _1 = vid
override def _2 = (pid, position)
override def canEqual(that: Any): Boolean = that.isInstanceOf[RoutingTableMessage]
}
private[graphx] private[graphx]
class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
/** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */ /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage]( new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage](
self, partitioner).setSerializer(new RoutingTableMessageSerializer) self, partitioner).setSerializer(new RoutingTableMessageSerializer)
} }
} }
...@@ -62,6 +49,23 @@ object RoutingTableMessageRDDFunctions { ...@@ -62,6 +49,23 @@ object RoutingTableMessageRDDFunctions {
private[graphx] private[graphx]
object RoutingTablePartition { object RoutingTablePartition {
/**
* A message from an edge partition to a vertex specifying the position in which the edge
* partition references the vertex (src, dst, or both). The edge partition is encoded in the lower
* 30 bytes of the Int, and the position is encoded in the upper 2 bytes of the Int.
*/
type RoutingTableMessage = (VertexId, Int)
private def toMessage(vid: VertexId, pid: PartitionID, position: Byte): RoutingTableMessage = {
val positionUpper2 = position << 30
val pidLower30 = pid & 0x3FFFFFFF
(vid, positionUpper2 | pidLower30)
}
private def vidFromMessage(msg: RoutingTableMessage): VertexId = msg._1
private def pidFromMessage(msg: RoutingTableMessage): PartitionID = msg._2 & 0x3FFFFFFF
private def positionFromMessage(msg: RoutingTableMessage): Byte = (msg._2 >> 30).toByte
val empty: RoutingTablePartition = new RoutingTablePartition(Array.empty) val empty: RoutingTablePartition = new RoutingTablePartition(Array.empty)
/** Generate a `RoutingTableMessage` for each vertex referenced in `edgePartition`. */ /** Generate a `RoutingTableMessage` for each vertex referenced in `edgePartition`. */
...@@ -77,7 +81,9 @@ object RoutingTablePartition { ...@@ -77,7 +81,9 @@ object RoutingTablePartition {
map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte) map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte)
} }
map.iterator.map { vidAndPosition => map.iterator.map { vidAndPosition =>
new RoutingTableMessage(vidAndPosition._1, pid, vidAndPosition._2) val vid = vidAndPosition._1
val position = vidAndPosition._2
toMessage(vid, pid, position)
} }
} }
...@@ -88,9 +94,12 @@ object RoutingTablePartition { ...@@ -88,9 +94,12 @@ object RoutingTablePartition {
val srcFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean]) val srcFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
val dstFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean]) val dstFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
for (msg <- iter) { for (msg <- iter) {
pid2vid(msg.pid) += msg.vid val vid = vidFromMessage(msg)
srcFlags(msg.pid) += (msg.position & 0x1) != 0 val pid = pidFromMessage(msg)
dstFlags(msg.pid) += (msg.position & 0x2) != 0 val position = positionFromMessage(msg)
pid2vid(pid) += vid
srcFlags(pid) += (position & 0x1) != 0
dstFlags(pid) += (position & 0x2) != 0
} }
new RoutingTablePartition(pid2vid.zipWithIndex.map { new RoutingTablePartition(pid2vid.zipWithIndex.map {
......
...@@ -24,9 +24,11 @@ import java.nio.ByteBuffer ...@@ -24,9 +24,11 @@ import java.nio.ByteBuffer
import scala.reflect.ClassTag import scala.reflect.ClassTag
import org.apache.spark.graphx._
import org.apache.spark.serializer._ import org.apache.spark.serializer._
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
private[graphx] private[graphx]
class RoutingTableMessageSerializer extends Serializer with Serializable { class RoutingTableMessageSerializer extends Serializer with Serializable {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
...@@ -35,10 +37,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable { ...@@ -35,10 +37,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
new ShuffleSerializationStream(s) { new ShuffleSerializationStream(s) {
def writeObject[T: ClassTag](t: T): SerializationStream = { def writeObject[T: ClassTag](t: T): SerializationStream = {
val msg = t.asInstanceOf[RoutingTableMessage] val msg = t.asInstanceOf[RoutingTableMessage]
writeVarLong(msg.vid, optimizePositive = false) writeVarLong(msg._1, optimizePositive = false)
writeUnsignedVarInt(msg.pid) writeInt(msg._2)
// TODO: Write only the bottom two bits of msg.position
s.write(msg.position)
this this
} }
} }
...@@ -47,10 +47,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable { ...@@ -47,10 +47,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
new ShuffleDeserializationStream(s) { new ShuffleDeserializationStream(s) {
override def readObject[T: ClassTag](): T = { override def readObject[T: ClassTag](): T = {
val a = readVarLong(optimizePositive = false) val a = readVarLong(optimizePositive = false)
val b = readUnsignedVarInt() val b = readInt()
val c = s.read() (a, b).asInstanceOf[T]
if (c == -1) throw new EOFException
new RoutingTableMessage(a, b, c.toByte).asInstanceOf[T]
} }
} }
} }
......
...@@ -30,7 +30,7 @@ package object graphx { ...@@ -30,7 +30,7 @@ package object graphx {
*/ */
type VertexId = Long type VertexId = Long
/** Integer identifer of a graph partition. */ /** Integer identifer of a graph partition. Must be less than 2^30. */
// TODO: Consider using Char. // TODO: Consider using Char.
type PartitionID = Int type PartitionID = Int
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册