From 47b6b38ca8d9c5de794183cc91cbf6559ef27390 Mon Sep 17 00:00:00 2001
From: jerryshao
Date: Fri, 25 Jul 2014 14:34:38 -0700
Subject: [PATCH] [SPARK-2125] Add sort flag and move sort into shuffle
 implementations

This patch adds a sort flag to ShuffleDependency and moves sorting into
the hash shuffle implementation. Moving sorting into the shuffle
implementation leaves room for other shuffle implementations (such as a
sort-based shuffle) to better optimize the sort as part of the shuffle.

Author: jerryshao

Closes #1210 from jerryshao/SPARK-2125 and squashes the following commits:

2feaf7b [jerryshao] revert MimaExcludes
ceddf75 [jerryshao] add MimaExeclude
f674ff4 [jerryshao] Add missing Scope restriction
b9fe0dd [jerryshao] Fix some style issues according to comments
ef6b729 [jerryshao] Change sort flag into Option
3f6eeed [jerryshao] Fix issues related to unit test
2f552a5 [jerryshao] Minor changes about naming and order
c92a281 [jerryshao] Move sort into shuffle implementations
---
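The user-visible behavior of sortByKey is unchanged by this patch: only the
place where the sort runs moves, from a mapPartitions pass over the
ShuffledRDD into the shuffle reader. A minimal driver program exercising the
changed path on Spark 1.x (the object name, local[2] master, and app name
below are illustrative, not part of the patch):

    import org.apache.spark.{SparkConf, SparkContext}
    // Spark 1.x: the implicit conversion to OrderedRDDFunctions, which
    // provides sortByKey, comes from the SparkContext companion object.
    import org.apache.spark.SparkContext._

    object SortByKeyDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("sortByKey-demo"))
        val pairs = sc.parallelize(Seq(3 -> "c", 1 -> "a", 2 -> "b"))
        // With this patch, sortByKey only range-partitions and tags the
        // ShuffledRDD with a SortOrder; the hash shuffle reader performs
        // the actual per-partition sort. The result is identical to before.
        println(pairs.sortByKey(ascending = false).collect().mkString(", "))
        // prints: (3,c), (2,b), (1,a)
        sc.stop()
      }
    }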
 .../scala/org/apache/spark/Dependency.scala        |  4 +++-
 .../apache/spark/rdd/OrderedRDDFunctions.scala     | 17 ++++++++---------
 .../org/apache/spark/rdd/ShuffledRDD.scala         | 12 +++++++++++-
 .../spark/shuffle/hash/HashShuffleReader.scala     | 14 +++++++++++++-
 4 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 09a6057123..f010c03223 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -19,6 +19,7 @@ package org.apache.spark
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.SortOrder.SortOrder
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.ShuffleHandle
 
@@ -62,7 +63,8 @@ class ShuffleDependency[K, V, C](
     val serializer: Option[Serializer] = None,
     val keyOrdering: Option[Ordering[K]] = None,
     val aggregator: Option[Aggregator[K, V, C]] = None,
-    val mapSideCombine: Boolean = false)
+    val mapSideCombine: Boolean = false,
+    val sortOrder: Option[SortOrder] = None)
   extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
 
   val shuffleId: Int = rdd.context.newShuffleId()
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index f1f4b4324e..afd7075f68 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -57,14 +57,13 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    */
   def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
     val part = new RangePartitioner(numPartitions, self, ascending)
-    val shuffled = new ShuffledRDD[K, V, V, P](self, part).setKeyOrdering(ordering)
-    shuffled.mapPartitions(iter => {
-      val buf = iter.toArray
-      if (ascending) {
-        buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator
-      } else {
-        buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator
-      }
-    }, preservesPartitioning = true)
+    new ShuffledRDD[K, V, V, P](self, part)
+      .setKeyOrdering(ordering)
+      .setSortOrder(if (ascending) SortOrder.ASCENDING else SortOrder.DESCENDING)
   }
 }
+
+private[spark] object SortOrder extends Enumeration {
+  type SortOrder = Value
+  val ASCENDING, DESCENDING = Value
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index bf02f68d0d..da4a8c3dc2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -21,6 +21,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark._
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.SortOrder.SortOrder
 import org.apache.spark.serializer.Serializer
 
 private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
@@ -51,6 +52,8 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
 
   private var mapSideCombine: Boolean = false
 
+  private var sortOrder: Option[SortOrder] = None
+
   /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
   def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = {
     this.serializer = Option(serializer)
@@ -75,8 +78,15 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
     this
   }
 
+  /** Set sort order for RDD's sorting. */
+  def setSortOrder(sortOrder: SortOrder): ShuffledRDD[K, V, C, P] = {
+    this.sortOrder = Option(sortOrder)
+    this
+  }
+
   override def getDependencies: Seq[Dependency[_]] = {
-    List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
+    List(new ShuffleDependency(prev, part, serializer,
+      keyOrdering, aggregator, mapSideCombine, sortOrder))
   }
 
   override val partitioner = Some(part)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index d45258c0a4..76cdb8f4f8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.shuffle.hash
 
 import org.apache.spark.{InterruptibleIterator, TaskContext}
+import org.apache.spark.rdd.SortOrder
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
 
@@ -38,7 +39,7 @@ class HashShuffleReader[K, C](
     val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
       Serializer.getSerializer(dep.serializer))
 
-    if (dep.aggregator.isDefined) {
+    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
         new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
       } else {
@@ -49,6 +50,17 @@
     } else {
       iter
     }
+
+    val sortedIter = for (sortOrder <- dep.sortOrder; ordering <- dep.keyOrdering) yield {
+      val buf = aggregatedIter.toArray
+      if (sortOrder == SortOrder.ASCENDING) {
+        buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator
+      } else {
+        buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator
+      }
+    }
+
+    sortedIter.getOrElse(aggregatedIter)
   }
 
   /** Close this reader */
--
GitLab
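One subtlety in the HashShuffleReader change above: the reader sorts only
when both dep.sortOrder and dep.keyOrdering are defined, and the
for-comprehension over the two Options encodes that compactly, yielding
Some(sorted iterator) only when both are Some and None otherwise. A
standalone sketch of the pattern (plain Scala, no Spark dependency; all
names below are illustrative):

    object DeferredSortSketch {

      // Mirrors the SortOrder enumeration added in OrderedRDDFunctions.scala.
      object SortOrder extends Enumeration {
        type SortOrder = Value
        val ASCENDING, DESCENDING = Value
      }
      import SortOrder.SortOrder

      // Mirrors the tail of HashShuffleReader.read: sort only if BOTH a sort
      // order and a key ordering are present, otherwise pass records through.
      def read[K, C](
          records: Iterator[(K, C)],
          sortOrder: Option[SortOrder],
          keyOrdering: Option[Ordering[K]]): Iterator[(K, C)] = {
        val sortedIter = for (order <- sortOrder; ordering <- keyOrdering) yield {
          // The whole partition is materialized before sorting, the same
          // trade-off the patch makes in HashShuffleReader.
          val buf = records.toArray
          if (order == SortOrder.ASCENDING) {
            buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator
          } else {
            buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator
          }
        }
        // None here means one of the two Options was empty: no sort requested.
        sortedIter.getOrElse(records)
      }

      def main(args: Array[String]): Unit = {
        val data = Iterator(3 -> "c", 1 -> "a", 2 -> "b")
        println(read(data, Some(SortOrder.DESCENDING), Some(Ordering.Int)).toList)
        // prints: List((3,c), (2,b), (1,a))
      }
    }

Leaving room for shuffle implementations that can avoid this full
materialization is exactly the motivation stated in the commit message.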