Commit 8573bf46 authored by Aljoscha Krettek

[scala] Simplify Operation Classes: No more Impl Classes

Parent 15e59906
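In other words: each public operation type used to be a trait paired with a hidden `*Impl` class whose only purpose was to keep constructor details private; this commit collapses each pair into a single concrete class. A minimal sketch of the before/after shape (simplified signatures, not the full API):

// Before: a public trait plus a private Impl class hiding the constructor.
trait GroupedDataSet[T] {
  def first(n: Int): DataSet[T]
}
private[flink] class GroupedDataSetImpl[T](set: JavaDataSet[T], keys: Keys[T])
  extends GroupedDataSet[T] {
  def first(n: Int): DataSet[T] = ???
}

// After: a single concrete class; the constructor parameters are plain private vals.
class GroupedDataSet[T](private val set: DataSet[T], private val keys: Keys[T]) {
  def first(n: Int): DataSet[T] = ???
}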
@@ -135,7 +135,7 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
return this.keys2;
}
protected JoinHint getJoinHint() {
public JoinHint getJoinHint() {
return this.joinHint;
}
......
@@ -22,6 +22,12 @@ import org.apache.flink.api.scala.operators.ScalaAggregateOperator
import scala.reflect.ClassTag
/**
* The result of [[DataSet.aggregate]]. This can be used to chain more aggregations to the
* one aggregate operator.
*
* @tparam T The type of the DataSet, i.e., the type of the elements of the DataSet.
*/
class AggregateDataSet[T: ClassTag](set: ScalaAggregateOperator[T])
extends DataSet[T](set) {
......
@@ -30,7 +30,6 @@ import org.apache.flink.api.java.aggregation.Aggregations
import org.apache.flink.api.java.operators._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.util.Collector
import org.apache.flink.api.java.{DataSet => JavaDataSet}
import scala.collection.mutable
import scala.reflect.ClassTag
@@ -42,126 +41,21 @@ import scala.reflect.ClassTag
* A secondary sort order can be added with sortGroup, but this is only used when using one
* of the group-at-a-time operations, i.e. `reduceGroup`.
*/
trait GroupedDataSet[T] {
class GroupedDataSet[T: ClassTag](
private val set: DataSet[T],
private val keys: Keys[T]) {
/**
* Adds a secondary sort key to this [[GroupedDataSet]]. This will only have an effect if you
* use one of the group-at-a-time operations, i.e. `reduceGroup`.
*
* This only works on Tuple DataSets.
*/
def sortGroup(field: Int, order: Order): GroupedDataSet[T]
// These are for optional secondary sort. They are only used
// when using a group-at-a-time reduce function.
private val groupSortKeyPositions = mutable.MutableList[Int]()
private val groupSortOrders = mutable.MutableList[Order]()
/**
* Adds a secondary sort key to this [[GroupedDataSet]]. This will only have an effect if you
* use one of the group-at-a-time operations, i.e. `reduceGroup`.
*
* This only works on CaseClass DataSets.
*/
def sortGroup(field: String, order: Order): GroupedDataSet[T]
/**
* Creates a new [[DataSet]] by aggregating the specified tuple field using the given aggregation
* function. Since this is a keyed DataSet the aggregation will be performed on groups of
* tuples with the same key.
*
* This only works on Tuple DataSets.
*/
def aggregate(agg: Aggregations, field: Int): AggregateDataSet[T]
/**
* Creates a new [[DataSet]] by aggregating the specified field using the given aggregation
* function. Since this is a keyed DataSet the aggregation will be performed on groups of
* elements with the same key.
*
* This only works on CaseClass DataSets.
*/
def aggregate(agg: Aggregations, field: String): AggregateDataSet[T]
/**
* Syntactic sugar for [[aggregate]] with `SUM`
*/
def sum(field: Int): AggregateDataSet[T]
/**
* Syntactic sugar for [[aggregate]] with `MAX`
*/
def max(field: Int): AggregateDataSet[T]
/**
* Syntactic sugar for [[aggregate]] with `MIN`
*/
def min(field: Int): AggregateDataSet[T]
/**
* Syntactic sugar for [[aggregate]] with `SUM`
*/
def sum(field: String): AggregateDataSet[T]
/**
* Syntactic sugar for [[aggregate]] with `MAX`
*/
def max(field: String): AggregateDataSet[T]
/**
* Syntactic sugar for [[aggregate]] with `MIN`
*/
def min(field: String): AggregateDataSet[T]
/**
* Creates a new [[DataSet]] by merging the elements of each group (elements with the same key)
* using an associative reduce function.
*/
def reduce(fun: (T, T) => T): DataSet[T]
/**
* Creates a new [[DataSet]] by merging the elements of each group (elements with the same key)
* using an associative reduce function.
*/
def reduce(reducer: ReduceFunction[T]): DataSet[T]
/**
* Creates a new [[DataSet]] by passing for each group (elements with the same key) the list
* of elements to the group reduce function. The function must output one element. The
* concatenation of those will form the resulting [[DataSet]].
*/
def reduceGroup[R: TypeInformation: ClassTag](fun: (TraversableOnce[T]) => R): DataSet[R]
/**
* Creates a new [[DataSet]] by passing for each group (elements with the same key) the list
* of elements to the group reduce function. The function can output zero or more elements using
* the [[Collector]]. The concatenation of the emitted values will form the resulting [[DataSet]].
*/
def reduceGroup[R: TypeInformation: ClassTag](
fun: (TraversableOnce[T], Collector[R]) => Unit): DataSet[R]
/**
* Creates a new [[DataSet]] by passing for each group (elements with the same key) the list
* of elements to the [[GroupReduceFunction]]. The function can output zero or more elements. The
* concatenation of the emitted values will form the resulting [[DataSet]].
*/
def reduceGroup[R: TypeInformation: ClassTag](reducer: GroupReduceFunction[T, R]): DataSet[R]
/**
* Creates a new DataSet containing the first `n` elements of each group of this DataSet.
*/
def first(n: Int): DataSet[T]
}
/**
* Private implementation for [[GroupedDataSet]] to keep the implementation details, i.e. the
* parameters of the constructor, hidden.
*/
private[flink] class GroupedDataSetImpl[T: ClassTag](
private val set: JavaDataSet[T],
private val keys: Keys[T])
extends GroupedDataSet[T] {
// These are for optional secondary sort. They are only used
// when using a group-at-a-time reduce function.
private val groupSortKeyPositions = mutable.MutableList[Int]()
private val groupSortOrders = mutable.MutableList[Order]()
def sortGroup(field: Int, order: Order): GroupedDataSet[T] = {
if (!set.getType.isTupleType) {
throw new InvalidProgramException("Specifying order keys via field positions is only valid " +
@@ -175,6 +69,12 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
this
}
/**
* Adds a secondary sort key to this [[GroupedDataSet]]. This will only have an effect if you
* use one of the group-at-a-time operations, i.e. `reduceGroup`.
*
* This only works on CaseClass DataSets.
*/
def sortGroup(field: String, order: Order): GroupedDataSet[T] = {
val fieldIndex = fieldNames2Indices(set.getType, Array(field))(0)
@@ -183,56 +83,99 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
this
}
/**
* Creates a [[SortedGrouping]] if group sorting keys were specified.
*/
private def maybeCreateSortedGrouping(): Grouping[T] = {
if (groupSortKeyPositions.length > 0) {
val grouping = new SortedGrouping[T](set, keys, groupSortKeyPositions(0), groupSortOrders(0))
val grouping = new SortedGrouping[T](
set.javaSet,
keys,
groupSortKeyPositions(0),
groupSortOrders(0))
// now manually add the rest of the keys
for (i <- 1 until groupSortKeyPositions.length) {
grouping.sortGroup(groupSortKeyPositions(i), groupSortOrders(i))
}
grouping
} else {
new UnsortedGrouping[T](set, keys)
new UnsortedGrouping[T](set.javaSet, keys)
}
}
/** Convenience methods for creating the [[UnsortedGrouping]] */
private def createUnsortedGrouping(): Grouping[T] = new UnsortedGrouping[T](set, keys)
private def createUnsortedGrouping(): Grouping[T] = new UnsortedGrouping[T](set.javaSet, keys)
/**
* Creates a new [[DataSet]] by aggregating the specified field using the given aggregation
* function. Since this is a keyed DataSet the aggregation will be performed on groups of
* elements with the same key.
*
* This only works on CaseClass DataSets.
*/
def aggregate(agg: Aggregations, field: String): AggregateDataSet[T] = {
val fieldIndex = fieldNames2Indices(set.getType, Array(field))(0)
new AggregateDataSet(new ScalaAggregateOperator[T](createUnsortedGrouping(), agg, fieldIndex))
}
/**
* Creates a new [[DataSet]] by aggregating the specified tuple field using the given aggregation
* function. Since this is a keyed DataSet the aggregation will be performed on groups of
* tuples with the same key.
*
* This only works on Tuple DataSets.
*/
def aggregate(agg: Aggregations, field: Int): AggregateDataSet[T] = {
new AggregateDataSet(new ScalaAggregateOperator[T](createUnsortedGrouping(), agg, field))
}
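For illustration, a hedged sketch of how these aggregations read at the call site (`env` and the sample data are assumed, with the implicits from `import org.apache.flink.api.scala._` in scope):

val words = env.fromElements(("a", 1), ("b", 2), ("a", 3))
// Group on the first tuple field, then sum the second field within each group.
val sums = words.groupBy(0).aggregate(Aggregations.SUM, 1)
// The same aggregation via the syntactic sugar defined below:
val sums2 = words.groupBy(0).sum(1)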
/**
* Syntactic sugar for [[aggregate]] with `SUM`
*/
def sum(field: Int) = {
aggregate(Aggregations.SUM, field)
}
/**
* Syntactic sugar for [[aggregate]] with `MAX`
*/
def max(field: Int) = {
aggregate(Aggregations.MAX, field)
}
/**
* Syntactic sugar for [[aggregate]] with `MIN`
*/
def min(field: Int) = {
aggregate(Aggregations.MIN, field)
}
/**
* Syntactic sugar for [[aggregate]] with `SUM`
*/
def sum(field: String) = {
aggregate(Aggregations.SUM, field)
}
/**
* Syntactic sugar for [[aggregate]] with `MAX`
*/
def max(field: String) = {
aggregate(Aggregations.MAX, field)
}
/**
* Syntactic sugar for [[aggregate]] with `MIN`
*/
def min(field: String) = {
aggregate(Aggregations.MIN, field)
}
/**
* Creates a new [[DataSet]] by merging the elements of each group (elements with the same key)
* using an associative reduce function.
*/
def reduce(fun: (T, T) => T): DataSet[T] = {
Validate.notNull(fun, "Reduce function must not be null.")
val reducer = new ReduceFunction[T] {
@@ -243,11 +186,20 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
wrap(new ReduceOperator[T](createUnsortedGrouping(), reducer))
}
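An assumed call-site sketch for the function variant, merging each group with an associative function:

// counts: DataSet[(String, Int)] is assumed; group on the word, sum the counts.
val totals = counts.groupBy(0).reduce { (a, b) => (a._1, a._2 + b._2) }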
/**
* Creates a new [[DataSet]] by merging the elements of each group (elements with the same key)
* using an associative reduce function.
*/
def reduce(reducer: ReduceFunction[T]): DataSet[T] = {
Validate.notNull(reducer, "Reduce function must not be null.")
wrap(new ReduceOperator[T](createUnsortedGrouping(), reducer))
}
/**
* Creates a new [[DataSet]] by passing for each group (elements with the same key) the list
* of elements to the group reduce function. The function must output one element. The
* concatenation of those will form the resulting [[DataSet]].
*/
def reduceGroup[R: TypeInformation: ClassTag](
fun: (TraversableOnce[T]) => R): DataSet[R] = {
Validate.notNull(fun, "Group reduce function must not be null.")
@@ -261,6 +213,11 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
implicitly[TypeInformation[R]], reducer))
}
/**
* Creates a new [[DataSet]] by passing for each group (elements with the same key) the list
* of elements to the group reduce function. The function can output zero or more elements using
* the [[Collector]]. The concatenation of the emitted values will form the resulting [[DataSet]].
*/
def reduceGroup[R: TypeInformation: ClassTag](
fun: (TraversableOnce[T], Collector[R]) => Unit): DataSet[R] = {
Validate.notNull(fun, "Group reduce function must not be null.")
@@ -274,6 +231,11 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
implicitly[TypeInformation[R]], reducer))
}
/**
* Creates a new [[DataSet]] by passing for each group (elements with the same key) the list
* of elements to the [[GroupReduceFunction]]. The function can output zero or more elements. The
* concatenation of the emitted values will form the resulting [[DataSet]].
*/
def reduceGroup[R: TypeInformation: ClassTag](reducer: GroupReduceFunction[T, R]): DataSet[R] = {
Validate.notNull(reducer, "GroupReduce function must not be null.")
wrap(
@@ -281,6 +243,9 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
implicitly[TypeInformation[R]], reducer))
}
/**
* Creates a new DataSet containing the first `n` elements of each group of this DataSet.
*/
def first(n: Int): DataSet[T] = {
if (n < 1) {
throw new InvalidProgramException("Parameter n of first(n) must be at least 1.")
......
@@ -18,12 +18,10 @@
package org.apache.flink.api.scala
import org.apache.commons.lang3.Validate
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.common.functions.{RichCoGroupFunction, CoGroupFunction}
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.api.java.operators._
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo
import org.apache.flink.api.java.{DataSet => JavaDataSet}
import org.apache.flink.api.scala.typeutils.{CaseClassSerializer, CaseClassTypeInfo}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.util.Collector
@@ -56,25 +54,63 @@ import scala.reflect.ClassTag
* }
* }}}
*
* @tparam T Type of the left input of the coGroup.
* @tparam O Type of the right input of the coGroup.
* @tparam L Type of the left input of the coGroup.
* @tparam R Type of the right input of the coGroup.
*/
trait CoGroupDataSet[T, O] extends DataSet[(Array[T], Array[O])] {
class CoGroupDataSet[L, R](
defaultCoGroup: CoGroupOperator[L, R, (Array[L], Array[R])],
leftInput: DataSet[L],
rightInput: DataSet[R],
leftKeys: Keys[L],
rightKeys: Keys[R])
extends DataSet(defaultCoGroup) {
/**
* Creates a new [[DataSet]] where the result for each pair of co-grouped element lists is the
* result of the given function.
*/
def apply[R: TypeInformation: ClassTag](
fun: (TraversableOnce[T], TraversableOnce[O]) => R): DataSet[R]
def apply[O: TypeInformation: ClassTag](
fun: (TraversableOnce[L], TraversableOnce[R]) => O): DataSet[O] = {
Validate.notNull(fun, "CoGroup function must not be null.")
val coGrouper = new CoGroupFunction[L, R, O] {
def coGroup(left: java.lang.Iterable[L], right: java.lang.Iterable[R], out: Collector[O]) = {
out.collect(fun(left.iterator.asScala, right.iterator.asScala))
}
}
val coGroupOperator = new CoGroupOperator[L, R, O](
leftInput.javaSet,
rightInput.javaSet,
leftKeys,
rightKeys,
coGrouper,
implicitly[TypeInformation[O]])
wrap(coGroupOperator)
}
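A hedged usage sketch for this variant (inputs assumed). Note that the keys are supplied through `where`/`equalTo` as defined in HalfUnfinishedKeyPairOperation further down, although some of the doc comments still refer to `isEqualTo`:

// authors: DataSet[(Int, String)] and books: DataSet[(Int, String)] are assumed.
val booksPerAuthor = authors.coGroup(books).where(0).equalTo(0) {
  (as, bs) => (as.size, bs.size)
}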
/**
* Creates a new [[DataSet]] where the result for each pair of co-grouped element lists is the
* result of the given function. The function can output zero or more elements using the
* [[Collector]] which will form the result.
*/
def apply[R: TypeInformation: ClassTag](
fun: (TraversableOnce[T], TraversableOnce[O], Collector[R]) => Unit): DataSet[R]
def apply[O: TypeInformation: ClassTag](
fun: (TraversableOnce[L], TraversableOnce[R], Collector[O]) => Unit): DataSet[O] = {
Validate.notNull(fun, "CoGroup function must not be null.")
val coGrouper = new CoGroupFunction[L, R, O] {
def coGroup(left: java.lang.Iterable[L], right: java.lang.Iterable[R], out: Collector[O]) = {
fun(left.iterator.asScala, right.iterator.asScala, out)
}
}
val coGroupOperator = new CoGroupOperator[L, R, O](
leftInput.javaSet,
rightInput.javaSet,
leftKeys,
rightKeys,
coGrouper,
implicitly[TypeInformation[O]])
wrap(coGroupOperator)
}
/**
* Creates a new [[DataSet]] by passing each pair of co-grouped element lists to the given
@@ -84,87 +120,45 @@ trait CoGroupDataSet[T, O] extends DataSet[(Array[T], Array[O])] {
* A [[RichCoGroupFunction]] can be used to access the
* broadcast variables and the [[org.apache.flink.api.common.functions.RuntimeContext]].
*/
def apply[R: TypeInformation: ClassTag](joiner: CoGroupFunction[T, O, R]): DataSet[R]
}
/**
* Private implementation for [[CoGroupDataSet]] to keep the implementation details, i.e. the
* parameters of the constructor, hidden.
*/
private[flink] class CoGroupDataSetImpl[T, O](
coGroupOperator: CoGroupOperator[T, O, (Array[T], Array[O])],
thisSet: DataSet[T],
otherSet: DataSet[O],
thisKeys: Keys[T],
otherKeys: Keys[O]) extends DataSet(coGroupOperator) with CoGroupDataSet[T, O] {
def apply[R: TypeInformation: ClassTag](
fun: (TraversableOnce[T], TraversableOnce[O]) => R): DataSet[R] = {
Validate.notNull(fun, "CoGroup function must not be null.")
val coGrouper = new CoGroupFunction[T, O, R] {
def coGroup(left: java.lang.Iterable[T], right: java.lang.Iterable[O], out: Collector[R]) = {
out.collect(fun(left.iterator.asScala, right.iterator.asScala))
}
}
val coGroupOperator = new CoGroupOperator[T, O, R](thisSet.set, otherSet.set, thisKeys,
otherKeys, coGrouper, implicitly[TypeInformation[R]])
wrap(coGroupOperator)
}
def apply[O: TypeInformation: ClassTag](coGrouper: CoGroupFunction[L, R, O]): DataSet[O] = {
Validate.notNull(coGrouper, "CoGroup function must not be null.")
val coGroupOperator = new CoGroupOperator[L, R, O](
leftInput.javaSet,
rightInput.javaSet,
leftKeys,
rightKeys,
coGrouper,
implicitly[TypeInformation[O]])
def apply[R: TypeInformation: ClassTag](
fun: (TraversableOnce[T], TraversableOnce[O], Collector[R]) => Unit): DataSet[R] = {
Validate.notNull(fun, "CoGroup function must not be null.")
val coGrouper = new CoGroupFunction[T, O, R] {
def coGroup(left: java.lang.Iterable[T], right: java.lang.Iterable[O], out: Collector[R]) = {
fun(left.iterator.asScala, right.iterator.asScala, out)
}
}
val coGroupOperator = new CoGroupOperator[T, O, R](thisSet.set, otherSet.set, thisKeys,
otherKeys, coGrouper, implicitly[TypeInformation[R]])
wrap(coGroupOperator)
}
def apply[R: TypeInformation: ClassTag](joiner: CoGroupFunction[T, O, R]): DataSet[R] = {
Validate.notNull(joiner, "CoGroup function must not be null.")
val coGroupOperator = new CoGroupOperator[T, O, R](thisSet.set, otherSet.set, thisKeys,
otherKeys, joiner, implicitly[TypeInformation[R]])
wrap(coGroupOperator)
}
}
/**
* An unfinished coGroup operation that results from [[DataSet.coGroup()]]. The keys for the left and
* An unfinished coGroup operation that results from [[DataSet.coGroup]]. The keys for the left and
* right side must be specified using first `where` and then `isEqualTo`. For example:
*
* {{{
* val left = ...
* val right = ...
* val joinResult = left.coGroup(right).where(...).isEqualTo(...)
* val coGroupResult = left.coGroup(right).where(...).isEqualTo(...)
* }}}
* @tparam T The type of the left input of the coGroup.
* @tparam O The type of the right input of the coGroup.
* @tparam L The type of the left input of the coGroup.
* @tparam R The type of the right input of the coGroup.
*/
trait UnfinishedCoGroupOperation[T, O]
extends UnfinishedKeyPairOperation[T, O, CoGroupDataSet[T, O]]
class UnfinishedCoGroupOperation[L: ClassTag, R: ClassTag](
leftInput: DataSet[L],
rightInput: DataSet[R])
extends UnfinishedKeyPairOperation[L, R, CoGroupDataSet[L, R]](leftInput, rightInput) {
/**
* Private implementation for [[UnfinishedCoGroupOperation]] to keep the implementation details,
* i.e. the parameters of the constructor, hidden.
*/
private[flink] class UnfinishedCoGroupOperationImpl[T: ClassTag, O: ClassTag](
leftSet: DataSet[T],
rightSet: DataSet[O])
extends UnfinishedKeyPairOperation[T, O, CoGroupDataSet[T, O]](leftSet, rightSet)
with UnfinishedCoGroupOperation[T, O] {
private[flink] def finish(leftKey: Keys[T], rightKey: Keys[O]) = {
val coGrouper = new CoGroupFunction[T, O, (Array[T], Array[O])] {
private[flink] def finish(leftKey: Keys[L], rightKey: Keys[R]) = {
val coGrouper = new CoGroupFunction[L, R, (Array[L], Array[R])] {
def coGroup(
left: java.lang.Iterable[T],
right: java.lang.Iterable[O],
out: Collector[(Array[T], Array[O])]) = {
val leftResult = Array[Any](left.asScala.toSeq: _*).asInstanceOf[Array[T]]
val rightResult = Array[Any](right.asScala.toSeq: _*).asInstanceOf[Array[O]]
left: java.lang.Iterable[L],
right: java.lang.Iterable[R],
out: Collector[(Array[L], Array[R])]) = {
val leftResult = Array[Any](left.asScala.toSeq: _*).asInstanceOf[Array[L]]
val rightResult = Array[Any](right.asScala.toSeq: _*).asInstanceOf[Array[R]]
out.collect((leftResult, rightResult))
}
@@ -173,59 +167,33 @@ private[flink] class UnfinishedCoGroupOperationImpl[T: ClassTag, O: ClassTag](
// We have to use this hack, for some reason classOf[Array[T]] does not work.
// Maybe because ObjectArrayTypeInfo does not accept the Scala Array as an array class.
val leftArrayType =
ObjectArrayTypeInfo.getInfoFor(new Array[T](0).getClass, leftSet.set.getType)
ObjectArrayTypeInfo.getInfoFor(new Array[L](0).getClass, leftInput.getType)
val rightArrayType =
ObjectArrayTypeInfo.getInfoFor(new Array[O](0).getClass, rightSet.set.getType)
ObjectArrayTypeInfo.getInfoFor(new Array[R](0).getClass, rightInput.getType)
val returnType = new CaseClassTypeInfo[(Array[T], Array[O])](
classOf[(Array[T], Array[O])], Seq(leftArrayType, rightArrayType), Array("_1", "_2")) {
val returnType = new CaseClassTypeInfo[(Array[L], Array[R])](
classOf[(Array[L], Array[R])], Seq(leftArrayType, rightArrayType), Array("_1", "_2")) {
override def createSerializer: TypeSerializer[(Array[T], Array[O])] = {
override def createSerializer: TypeSerializer[(Array[L], Array[R])] = {
val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
for (i <- 0 until getArity()) {
for (i <- 0 until getArity) {
fieldSerializers(i) = types(i).createSerializer
}
new CaseClassSerializer[(Array[T], Array[O])](
classOf[(Array[T], Array[O])],
new CaseClassSerializer[(Array[L], Array[R])](
classOf[(Array[L], Array[R])],
fieldSerializers) {
override def createInstance(fields: Array[AnyRef]) = {
(fields(0).asInstanceOf[Array[T]], fields(1).asInstanceOf[Array[O]])
(fields(0).asInstanceOf[Array[L]], fields(1).asInstanceOf[Array[R]])
}
}
}
}
val coGroupOperator = new CoGroupOperator[T, O, (Array[T], Array[O])](
leftSet.set, rightSet.set, leftKey, rightKey, coGrouper, returnType)
// sanity check solution set key mismatches
leftSet.set match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
leftKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
val positions: Array[Int] = keyFields.computeLogicalKeyPositions
solutionSet.checkJoinKeyFields(positions)
case _ =>
throw new InvalidProgramException("Currently, the solution set may only be joined " +
"with " +
"using tuple field positions.")
}
case _ =>
}
rightSet.set match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
rightKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
val positions: Array[Int] = keyFields.computeLogicalKeyPositions
solutionSet.checkJoinKeyFields(positions)
case _ =>
throw new InvalidProgramException("Currently, the solution set may only be joined " +
"with " +
"using tuple field positions.")
}
case _ =>
}
val coGroupOperator = new CoGroupOperator[L, R, (Array[L], Array[R])](
leftInput.javaSet, rightInput.javaSet, leftKey, rightKey, coGrouper, returnType)
DeltaIterationSanityCheck(leftInput, rightInput, leftKey, rightKey)
new CoGroupDataSetImpl(coGroupOperator, leftSet, rightSet, leftKey, rightKey)
new CoGroupDataSet(coGroupOperator, leftInput, rightInput, leftKey, rightKey)
}
}
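A side note on the array-class trick used in `finish` above (a plausible reading of the comment, not stated in the commit): `classOf[Array[L]]` is not expressible when `L` is an abstract type parameter, but a ClassTag lets us materialize an empty array whose runtime class is exactly what `ObjectArrayTypeInfo.getInfoFor` needs:

// Hypothetical helper showing the same trick in isolation.
def arrayClassOf[L: scala.reflect.ClassTag]: Class[_] = new Array[L](0).getClass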
@@ -43,16 +43,33 @@ import scala.reflect.ClassTag
* }
* }}}
*
* @tparam T Type of the left input of the cross.
* @tparam O Type of the right input of the cross.
* @tparam L Type of the left input of the cross.
* @tparam R Type of the right input of the cross.
*/
trait CrossDataSet[T, O] extends DataSet[(T, O)] {
class CrossDataSet[L, R](
defaultCross: CrossOperator[L, R, (L, R)],
leftInput: DataSet[L],
rightInput: DataSet[R])
extends DataSet(defaultCross) {
/**
* Creates a new [[DataSet]] where the result for each pair of elements is the result
* of the given function.
*/
def apply[R: TypeInformation: ClassTag](fun: (T, O) => R): DataSet[R]
def apply[O: TypeInformation: ClassTag](fun: (L, R) => O): DataSet[O] = {
Validate.notNull(fun, "Cross function must not be null.")
val crosser = new CrossFunction[L, R, O] {
def cross(left: L, right: R): O = {
fun(left, right)
}
}
val crossOperator = new CrossOperator[L, R, O](
leftInput.javaSet,
rightInput.javaSet,
crosser,
implicitly[TypeInformation[O]])
wrap(crossOperator)
}
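An assumed call-site sketch: `cross` builds the full cartesian product and the function shapes each pair.

// xs: DataSet[Int] and ys: DataSet[String] are assumed.
val labelled = xs.cross(ys) { (x, y) => (y, x * 2) }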
/**
* Creates a new [[DataSet]] by passing each pair of values to the given function.
@@ -62,71 +79,50 @@ trait CrossDataSet[T, O] extends DataSet[(T, O)] {
* A [[RichCrossFunction]] can be used to access the
* broadcast variables and the [[org.apache.flink.api.common.functions.RuntimeContext]].
*/
def apply[R: TypeInformation: ClassTag](joiner: CrossFunction[T, O, R]): DataSet[R]
}
/**
* Private implementation for [[CrossDataSet]] to keep the implementation details, i.e. the
* parameters of the constructor, hidden.
*/
private[flink] class CrossDataSetImpl[T, O](
crossOperator: CrossOperator[T, O, (T, O)],
thisSet: JavaDataSet[T],
otherSet: JavaDataSet[O])
extends DataSet(crossOperator)
with CrossDataSet[T, O] {
def apply[R: TypeInformation: ClassTag](fun: (T, O) => R): DataSet[R] = {
Validate.notNull(fun, "Cross function must not be null.")
val crosser = new CrossFunction[T, O, R] {
def cross(left: T, right: O): R = {
fun(left, right)
}
}
val crossOperator = new CrossOperator[T, O, R](
thisSet,
otherSet,
crosser,
implicitly[TypeInformation[R]])
wrap(crossOperator)
}
def apply[R: TypeInformation: ClassTag](crosser: CrossFunction[T, O, R]): DataSet[R] = {
def apply[O: TypeInformation: ClassTag](crosser: CrossFunction[L, R, O]): DataSet[O] = {
Validate.notNull(crosser, "Cross function must not be null.")
val crossOperator = new CrossOperator[T, O, R](
thisSet,
otherSet,
val crossOperator = new CrossOperator[L, R, O](
leftInput.javaSet,
rightInput.javaSet,
crosser,
implicitly[TypeInformation[R]])
implicitly[TypeInformation[O]])
wrap(crossOperator)
}
}
private[flink] object CrossDataSetImpl {
def createCrossOperator[T, O](leftSet: JavaDataSet[T], rightSet: JavaDataSet[O]) = {
val crosser = new CrossFunction[T, O, (T, O)] {
def cross(left: T, right: O) = {
private[flink] object CrossDataSet {
/**
* Creates a default cross operation with Tuple2 as result.
*/
def createCrossOperator[L, R](leftInput: DataSet[L], rightInput: DataSet[R]) = {
val crosser = new CrossFunction[L, R, (L, R)] {
def cross(left: L, right: R) = {
(left, right)
}
}
val returnType = new CaseClassTypeInfo[(T, O)](
classOf[(T, O)], Seq(leftSet.getType, rightSet.getType), Array("_1", "_2")) {
val returnType = new CaseClassTypeInfo[(L, R)](
classOf[(L, R)], Seq(leftInput.getType, rightInput.getType), Array("_1", "_2")) {
override def createSerializer: TypeSerializer[(T, O)] = {
override def createSerializer: TypeSerializer[(L, R)] = {
val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
for (i <- 0 until getArity) {
fieldSerializers(i) = types(i).createSerializer
}
new CaseClassSerializer[(T, O)](classOf[(T, O)], fieldSerializers) {
new CaseClassSerializer[(L, R)](classOf[(L, R)], fieldSerializers) {
override def createInstance(fields: Array[AnyRef]) = {
(fields(0).asInstanceOf[T], fields(1).asInstanceOf[O])
(fields(0).asInstanceOf[L], fields(1).asInstanceOf[R])
}
}
}
}
val crossOperator = new CrossOperator[T, O, (T, O)](leftSet, rightSet, crosser, returnType)
val crossOperator = new CrossOperator[L, R, (L, R)](
leftInput.javaSet,
rightInput.javaSet,
crosser,
returnType)
new CrossDataSetImpl(crossOperator, leftSet, rightSet)
new CrossDataSet(crossOperator, leftInput, rightInput)
}
}
@@ -24,7 +24,6 @@ import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.api.java.operators.JoinOperator.DefaultJoin.WrappingFlatJoinFunction
import org.apache.flink.api.java.operators.JoinOperator.{EquiJoin, JoinHint}
import org.apache.flink.api.java.operators._
import org.apache.flink.api.java.{DataSet => JavaDataSet}
import org.apache.flink.api.scala.typeutils.{CaseClassSerializer, CaseClassTypeInfo}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.util.Collector
@@ -55,23 +54,63 @@ import scala.reflect.ClassTag
* }
* }}}
*
* @tparam T Type of the left input of the join.
* @tparam O Type of the right input of the join.
* @tparam L Type of the left input of the join.
* @tparam R Type of the right input of the join.
*/
trait JoinDataSet[T, O] extends DataSet[(T, O)] {
class JoinDataSet[L, R](
defaultJoin: EquiJoin[L, R, (L, R)],
leftInput: DataSet[L],
rightInput: DataSet[R],
leftKeys: Keys[L],
rightKeys: Keys[R])
extends DataSet(defaultJoin) {
/**
* Creates a new [[DataSet]] where the result for each pair of joined elements is the result
* of the given function.
*/
def apply[R: TypeInformation: ClassTag](fun: (T, O) => R): DataSet[R]
def apply[O: TypeInformation: ClassTag](fun: (L, R) => O): DataSet[O] = {
Validate.notNull(fun, "Join function must not be null.")
val joiner = new FlatJoinFunction[L, R, O] {
def join(left: L, right: R, out: Collector[O]) = {
out.collect(fun(left, right))
}
}
val joinOperator = new EquiJoin[L, R, O](
leftInput.javaSet,
rightInput.javaSet,
leftKeys,
rightKeys,
joiner,
implicitly[TypeInformation[O]],
defaultJoin.getJoinHint)
wrap(joinOperator)
}
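An assumed call-site sketch; the result type of the join is whatever the function returns:

// orders: DataSet[(Int, Double)] and customers: DataSet[(Int, String)] are assumed.
val enriched = orders.join(customers).where(0).equalTo(0) {
  (order, customer) => (customer._2, order._2)
}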
/**
* Creates a new [[DataSet]] by passing each pair of joined values to the given function.
* The function can output zero or more elements using the [[Collector]] which will form the
* result.
*/
def apply[R: TypeInformation: ClassTag](fun: (T, O, Collector[R]) => Unit): DataSet[R]
def apply[O: TypeInformation: ClassTag](fun: (L, R, Collector[O]) => Unit): DataSet[O] = {
Validate.notNull(fun, "Join function must not be null.")
val joiner = new FlatJoinFunction[L, R, O] {
def join(left: L, right: R, out: Collector[O]) = {
fun(left, right, out)
}
}
val joinOperator = new EquiJoin[L, R, O](
leftInput.javaSet,
rightInput.javaSet,
leftKeys,
rightKeys,
joiner,
implicitly[TypeInformation[O]],
defaultJoin.getJoinHint)
wrap(joinOperator)
}
/**
* Creates a new [[DataSet]] by passing each pair of joined values to the given function.
@@ -81,7 +120,20 @@ trait JoinDataSet[T, O] extends DataSet[(T, O)] {
* A [[RichFlatJoinFunction]] can be used to access the
* broadcast variables and the [[org.apache.flink.api.common.functions.RuntimeContext]].
*/
def apply[R: TypeInformation: ClassTag](joiner: FlatJoinFunction[T, O, R]): DataSet[R]
def apply[O: TypeInformation: ClassTag](joiner: FlatJoinFunction[L, R, O]): DataSet[O] = {
Validate.notNull(joiner, "Join function must not be null.")
val joinOperator = new EquiJoin[L, R, O](
leftInput.javaSet,
rightInput.javaSet,
leftKeys,
rightKeys,
joiner,
implicitly[TypeInformation[O]],
defaultJoin.getJoinHint)
wrap(joinOperator)
}
/**
* Creates a new [[DataSet]] by passing each pair of joined values to the given function.
@@ -90,60 +142,20 @@ trait JoinDataSet[T, O] extends DataSet[(T, O)] {
* A [[org.apache.flink.api.common.functions.RichJoinFunction]] can be used to access the
* broadcast variables and the [[org.apache.flink.api.common.functions.RuntimeContext]].
*/
def apply[R: TypeInformation: ClassTag](joiner: JoinFunction[T, O, R]): DataSet[R]
}
/**
* Private implementation for [[JoinDataSet]] to keep the implementation details, i.e. the
* parameters of the constructor, hidden.
*/
private[flink] class JoinDataSetImpl[T, O](
joinOperator: EquiJoin[T, O, (T, O)],
thisSet: JavaDataSet[T],
otherSet: JavaDataSet[O],
thisKeys: Keys[T],
otherKeys: Keys[O])
extends DataSet(joinOperator)
with JoinDataSet[T, O] {
def apply[R: TypeInformation: ClassTag](fun: (T, O) => R): DataSet[R] = {
def apply[O: TypeInformation: ClassTag](fun: JoinFunction[L, R, O]): DataSet[O] = {
Validate.notNull(fun, "Join function must not be null.")
val joiner = new FlatJoinFunction[T, O, R] {
def join(left: T, right: O, out: Collector[R]) = {
out.collect(fun(left, right))
}
}
val joinOperator = new EquiJoin[T, O, R](thisSet, otherSet, thisKeys,
otherKeys, joiner, implicitly[TypeInformation[R]], JoinHint.OPTIMIZER_CHOOSES)
wrap(joinOperator)
}
def apply[R: TypeInformation: ClassTag](fun: (T, O, Collector[R]) => Unit): DataSet[R] = {
Validate.notNull(fun, "Join function must not be null.")
val joiner = new FlatJoinFunction[T, O, R] {
def join(left: T, right: O, out: Collector[R]) = {
fun(left, right, out)
}
}
val joinOperator = new EquiJoin[T, O, R](thisSet, otherSet, thisKeys,
otherKeys, joiner, implicitly[TypeInformation[R]], JoinHint.OPTIMIZER_CHOOSES)
wrap(joinOperator)
}
def apply[R: TypeInformation: ClassTag](joiner: FlatJoinFunction[T, O, R]): DataSet[R] = {
Validate.notNull(joiner, "Join function must not be null.")
val joinOperator = new EquiJoin[T, O, R](thisSet, otherSet, thisKeys,
otherKeys, joiner, implicitly[TypeInformation[R]], JoinHint.OPTIMIZER_CHOOSES)
wrap(joinOperator)
}
val generatedFunction: FlatJoinFunction[L, R, O] = new WrappingFlatJoinFunction[L, R, O](fun)
def apply[R: TypeInformation: ClassTag](fun: JoinFunction[T, O, R]): DataSet[R] = {
Validate.notNull(fun, "Join function must not be null.")
val joinOperator = new EquiJoin[L, R, O](
leftInput.javaSet,
rightInput.javaSet,
leftKeys,
rightKeys,
generatedFunction, fun,
implicitly[TypeInformation[O]],
defaultJoin.getJoinHint)
val generatedFunction: FlatJoinFunction[T, O, R] = new WrappingFlatJoinFunction[T, O, R](fun)
val joinOperator = new EquiJoin[T, O, R](thisSet, otherSet, thisKeys,
otherKeys, generatedFunction, fun, implicitly[TypeInformation[R]], JoinHint.OPTIMIZER_CHOOSES)
wrap(joinOperator)
}
}
@@ -157,49 +169,59 @@ private[flink] class JoinDataSetImpl[T, O](
* val right = ...
* val joinResult = left.join(right).where(...).isEqualTo(...)
* }}}
* @tparam T The type of the left input of the join.
* @tparam O The type of the right input of the join.
*/
trait UnfinishedJoinOperation[T, O] extends UnfinishedKeyPairOperation[T, O, JoinDataSet[T, O]]
/**
* Private implementation for [[UnfinishedJoinOperation]] to keep the implementation details,
* i.e. the parameters of the constructor, hidden.
* @tparam L The type of the left input of the join.
* @tparam R The type of the right input of the join.
*/
private[flink] class UnfinishedJoinOperationImpl[T, O](
leftSet: DataSet[T],
rightSet: DataSet[O],
joinHint: JoinHint)
extends UnfinishedKeyPairOperation[T, O, JoinDataSet[T, O]](leftSet, rightSet)
with UnfinishedJoinOperation[T, O] {
private[flink] def finish(leftKey: Keys[T], rightKey: Keys[O]) = {
val joiner = new FlatJoinFunction[T, O, (T, O)] {
def join(left: T, right: O, out: Collector[(T, O)]) = {
class UnfinishedJoinOperation[L, R](
leftSet: DataSet[L],
rightSet: DataSet[R],
val joinHint: JoinHint)
extends UnfinishedKeyPairOperation[L, R, JoinDataSet[L, R]](leftSet, rightSet) {
private[flink] def finish(leftKey: Keys[L], rightKey: Keys[R]) = {
val joiner = new FlatJoinFunction[L, R, (L, R)] {
def join(left: L, right: R, out: Collector[(L, R)]) = {
out.collect((left, right))
}
}
val returnType = new CaseClassTypeInfo[(T, O)](
classOf[(T, O)], Seq(leftSet.set.getType, rightSet.set.getType), Array("_1", "_2")) {
val returnType = new CaseClassTypeInfo[(L, R)](
classOf[(L, R)], Seq(leftSet.getType, rightSet.getType), Array("_1", "_2")) {
override def createSerializer: TypeSerializer[(T, O)] = {
override def createSerializer: TypeSerializer[(L, R)] = {
val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
for (i <- 0 until getArity()) {
for (i <- 0 until getArity) {
fieldSerializers(i) = types(i).createSerializer
}
new CaseClassSerializer[(T, O)](classOf[(T, O)], fieldSerializers) {
new CaseClassSerializer[(L, R)](classOf[(L, R)], fieldSerializers) {
override def createInstance(fields: Array[AnyRef]) = {
(fields(0).asInstanceOf[T], fields(1).asInstanceOf[O])
(fields(0).asInstanceOf[L], fields(1).asInstanceOf[R])
}
}
}
}
val joinOperator = new EquiJoin[T, O, (T, O)](
leftSet.set, rightSet.set, leftKey, rightKey, joiner, returnType, joinHint)
val joinOperator = new EquiJoin[L, R, (L, R)](
leftSet.javaSet, rightSet.javaSet, leftKey, rightKey, joiner, returnType, joinHint)
DeltaIterationSanityCheck(leftSet, rightSet, leftKey, rightKey)
new JoinDataSet(joinOperator, leftSet, rightSet, leftKey, rightKey)
}
}
/**
* This checks whether joining/coGrouping with the DeltaIteration SolutionSet uses the
* same key given when creating the DeltaIteration.
*/
private[flink] object DeltaIterationSanityCheck {
def apply[L, R](
leftSet: DataSet[L],
rightSet: DataSet[R],
leftKey: Keys[L],
rightKey: Keys[R]) = {
// sanity check solution set key mismatches
leftSet.set match {
leftSet.javaSet match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
leftKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
@@ -212,7 +234,7 @@ private[flink] class UnfinishedJoinOperationImpl[T, O](
}
case _ =>
}
rightSet.set match {
rightSet.javaSet match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
rightKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
@@ -225,7 +247,5 @@ private[flink] class UnfinishedJoinOperationImpl[T, O](
}
case _ =>
}
new JoinDataSetImpl(joinOperator, leftSet.set, rightSet.set, leftKey, rightKey)
}
}
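A hedged sketch of the situation this check guards (names and the Scala `iterateDelta` signature are assumed): within a delta iteration, the solution set may only be joined or coGrouped on the key fields declared for the iteration.

// initial and workset are assumed DataSet[(Int, Double)]; field 0 is the iteration key.
val result = initial.iterateDelta(workset, 10, Array(0)) { (solution, ws) =>
  // Fine: the solution set is joined on field 0, the declared key.
  val next = ws.join(solution).where(0).equalTo(0) { (w, s) => (w._1, w._2 + s._2) }
  // Joining solution on some other field would raise the InvalidProgramException above.
  (next, next)
}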
@@ -49,7 +49,14 @@ package object scala {
fields: Array[String]): Array[Int] = {
typeInfo match {
case ti: CaseClassTypeInfo[_] =>
ti.getFieldIndices(fields)
val result = ti.getFieldIndices(fields)
if (result.contains(-1)) {
throw new IllegalArgumentException("Fields '" + fields.mkString(", ") +
"' are not valid for '" + ti.toString + "'.")
}
result
case _ =>
throw new UnsupportedOperationException("Specifying fields by name is only" +
......
@@ -77,13 +77,10 @@ abstract class CaseClassTypeInfo[T <: Product](
}
def getFieldIndices(fields: Array[String]): Array[Int] = {
val result = fields map { x => fieldNames.indexOf(x) }
if (result.contains(-1)) {
throw new IllegalArgumentException("Fields '" + fields.mkString(", ") +
"' are not valid for " + clazz + " with fields '" + fieldNames.mkString(", ") + "'.")
}
result
fields map { x => fieldNames.indexOf(x) }
}
override def toString = "Scala " + super.toString
override def toString = clazz.getSimpleName + "(" + fieldNames.zip(types).map {
case (n, t) => n + ": " + t}
.mkString(", ") + ")"
}
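The net effect of the two hunks above: `getFieldIndices` now simply reports unknown names as `-1`, and `fieldNames2Indices` in the package object becomes the single place that raises, using the new `toString`. A sketch with a hypothetical case class:

// case class Person(name: String, age: Int), with personType its CaseClassTypeInfo.
// personType.getFieldIndices(Array("name", "zip"))  // now returns Array(0, -1)
// fieldNames2Indices(personType, Array("zip"))      // throws IllegalArgumentException,
//   quoting the new toString form, roughly "Person(name: ..., age: ...)".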
@@ -19,12 +19,10 @@
package org.apache.flink.api.scala
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.java.{DataSet => JavaDataSet}
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.java.operators.Keys
import org.apache.flink.api.java.operators.Keys.FieldPositionKeys
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.api.common.typeinfo.TypeInformation
/**
@@ -35,19 +33,19 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
* This way, we have a central point where all the key-providing happens and don't need to change
* the specific operations if the supported key types change.
*
* We use the type parameter `R` to specify the type of the resulting operation. For join
* this would be `JoinDataSet[T, O]` and for coGroup it would be `CoGroupDataSet[T, O]`. This
* We use the type parameter `O` to specify the type of the resulting operation. For join
* this would be `JoinDataSet[L, R]` and for coGroup it would be `CoGroupDataSet[L, R]`. This
* way the user gets the correct type for the finished operation.
*
* @tparam T Type of the left input [[DataSet]].
* @tparam O Type of the right input [[DataSet]].
* @tparam R The type of the resulting Operation.
* @tparam L Type of the left input [[DataSet]].
* @tparam R Type of the right input [[DataSet]].
* @tparam O The type of the resulting Operation.
*/
private[flink] abstract class UnfinishedKeyPairOperation[T, O, R](
private[flink] val leftSet: DataSet[T],
private[flink] val rightSet: DataSet[O]) {
private[flink] abstract class UnfinishedKeyPairOperation[L, R, O](
private[flink] val leftInput: DataSet[L],
private[flink] val rightInput: DataSet[R]) {
private[flink] def finish(leftKey: Keys[T], rightKey: Keys[O]): R
private[flink] def finish(leftKey: Keys[L], rightKey: Keys[R]): O
/**
* Specify the key fields for the left side of the key based operation. This returns
@@ -58,8 +56,8 @@ private[flink] abstract class UnfinishedKeyPairOperation[T, O, R](
* This only works on Tuple [[DataSet]].
*/
def where(leftKeys: Int*) = {
val leftKey = new FieldPositionKeys[T](leftKeys.toArray, leftSet.set.getType)
new HalfUnfinishedKeyPairOperation[T, O, R](this, leftKey)
val leftKey = new FieldPositionKeys[L](leftKeys.toArray, leftInput.getType)
new HalfUnfinishedKeyPairOperation[L, R, O](this, leftKey)
}
/**
@@ -72,11 +70,11 @@ private[flink] abstract class UnfinishedKeyPairOperation[T, O, R](
*/
def where(firstLeftField: String, otherLeftFields: String*) = {
val fieldIndices = fieldNames2Indices(
leftSet.set.getType,
leftInput.getType,
firstLeftField +: otherLeftFields.toArray)
val leftKey = new FieldPositionKeys[T](fieldIndices, leftSet.set.getType)
new HalfUnfinishedKeyPairOperation[T, O, R](this, leftKey)
val leftKey = new FieldPositionKeys[L](fieldIndices, leftInput.getType)
new HalfUnfinishedKeyPairOperation[L, R, O](this, leftKey)
}
/**
@@ -85,18 +83,18 @@ private[flink] abstract class UnfinishedKeyPairOperation[T, O, R](
* key for the right side. The result after specifying the right side key is the finished
* operation.
*/
def where[K: TypeInformation](fun: (T) => K) = {
def where[K: TypeInformation](fun: (L) => K) = {
val keyType = implicitly[TypeInformation[K]]
val keyExtractor = new KeySelector[T, K] {
def getKey(in: T) = fun(in)
val keyExtractor = new KeySelector[L, K] {
def getKey(in: L) = fun(in)
}
val leftKey = new Keys.SelectorFunctionKeys[T, K](keyExtractor, leftSet.set.getType, keyType)
new HalfUnfinishedKeyPairOperation[T, O, R](this, leftKey)
val leftKey = new Keys.SelectorFunctionKeys[L, K](keyExtractor, leftInput.getType, keyType)
new HalfUnfinishedKeyPairOperation[L, R, O](this, leftKey)
}
}
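To make the type-parameter flow concrete, a hedged sketch with the types written out (inputs assumed, and assuming `DataSet.join` returns the UnfinishedJoinOperation shown earlier in this diff):

val left: DataSet[(Int, String)] = ???
val right: DataSet[(Int, Double)] = ???
// where() fixes the left key and half-finishes the operation; equalTo() (defined in
// HalfUnfinishedKeyPairOperation below) supplies the right key and yields O:
val unfinished: UnfinishedJoinOperation[(Int, String), (Int, Double)] = left.join(right)
val done: JoinDataSet[(Int, String), (Int, Double)] = unfinished.where(0).equalTo(0)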
private[flink] class HalfUnfinishedKeyPairOperation[T, O, R](
unfinished: UnfinishedKeyPairOperation[T, O, R], leftKey: Keys[T]) {
private[flink] class HalfUnfinishedKeyPairOperation[L, R, O](
unfinished: UnfinishedKeyPairOperation[L, R, O], leftKey: Keys[L]) {
/**
* Specify the key fields for the right side of the key based operation. This returns
@@ -104,8 +102,8 @@ private[flink] class HalfUnfinishedKeyPairOperation[T, O, R](
*
* This only works on a Tuple [[DataSet]].
*/
def equalTo(rightKeys: Int*): R = {
val rightKey = new FieldPositionKeys[O](rightKeys.toArray, unfinished.rightSet.set.getType)
def equalTo(rightKeys: Int*): O = {
val rightKey = new FieldPositionKeys[R](rightKeys.toArray, unfinished.rightInput.getType)
if (!leftKey.areCompatibale(rightKey)) {
throw new InvalidProgramException("The types of the key fields do not match. Left: " +
leftKey + " Right: " + rightKey)
@@ -119,12 +117,12 @@ private[flink] class HalfUnfinishedKeyPairOperation[T, O, R](
*
* This only works on a CaseClass [[DataSet]].
*/
def equalTo(firstRightField: String, otherRightFields: String*): R = {
def equalTo(firstRightField: String, otherRightFields: String*): O = {
val fieldIndices = fieldNames2Indices(
unfinished.rightSet.set.getType,
unfinished.rightInput.getType,
firstRightField +: otherRightFields.toArray)
val rightKey = new FieldPositionKeys[O](fieldIndices, unfinished.rightSet.set.getType)
val rightKey = new FieldPositionKeys[R](fieldIndices, unfinished.rightInput.getType)
if (!leftKey.areCompatibale(rightKey)) {
throw new InvalidProgramException("The types of the key fields do not match. Left: " +
leftKey + " Right: " + rightKey)
@@ -137,13 +135,16 @@ private[flink] class HalfUnfinishedKeyPairOperation[T, O, R](
* Specify the key selector function for the right side of the key based operation. This returns
* the finished operation.
*/
def equalTo[K: TypeInformation](fun: (O) => K) = {
def equalTo[K: TypeInformation](fun: (R) => K) = {
val keyType = implicitly[TypeInformation[K]]
val keyExtractor = new KeySelector[O, K] {
def getKey(in: O) = fun(in)
val keyExtractor = new KeySelector[R, K] {
def getKey(in: R) = fun(in)
}
val rightKey =
new Keys.SelectorFunctionKeys[O, K](keyExtractor, unfinished.rightSet.set.getType, keyType)
val rightKey = new Keys.SelectorFunctionKeys[R, K](
keyExtractor,
unfinished.rightInput.getType,
keyType)
if (!leftKey.areCompatibale(rightKey)) {
throw new InvalidProgramException("The types of the key fields do not match. Left: " +
leftKey + " Right: " + rightKey)
......
@@ -48,8 +48,8 @@ class ReduceTranslationTest {
val sink: GenericDataSinkBase[_] = p.getDataSinks.iterator.next
val reducer: ReduceOperatorBase[_, _] = sink.getInput.asInstanceOf[ReduceOperatorBase[_, _]]
assertEquals(initialData.set.getType, reducer.getOperatorInfo.getInputType)
assertEquals(initialData.set.getType, reducer.getOperatorInfo.getOutputType)
assertEquals(initialData.javaSet.getType, reducer.getOperatorInfo.getInputType)
assertEquals(initialData.javaSet.getType, reducer.getOperatorInfo.getOutputType)
assertTrue(reducer.getKeyColumns(0) == null || reducer.getKeyColumns(0).length == 0)
assertTrue(reducer.getDegreeOfParallelism == 1 || reducer.getDegreeOfParallelism == -1)
assertTrue(reducer.getInput.isInstanceOf[GenericDataSourceBase[_, _]])
@@ -77,8 +77,8 @@ class ReduceTranslationTest {
val sink: GenericDataSinkBase[_] = p.getDataSinks.iterator.next
val reducer: ReduceOperatorBase[_, _] = sink.getInput.asInstanceOf[ReduceOperatorBase[_, _]]
assertEquals(initialData.set.getType, reducer.getOperatorInfo.getInputType)
assertEquals(initialData.set.getType, reducer.getOperatorInfo.getOutputType)
assertEquals(initialData.javaSet.getType, reducer.getOperatorInfo.getInputType)
assertEquals(initialData.javaSet.getType, reducer.getOperatorInfo.getOutputType)
assertTrue(reducer.getDegreeOfParallelism == DOP || reducer.getDegreeOfParallelism == -1)
assertArrayEquals(Array[Int](2), reducer.getKeyColumns(0))
assertTrue(reducer.getInput.isInstanceOf[GenericDataSourceBase[_, _]])
@@ -116,12 +116,12 @@ class ReduceTranslationTest {
val keyValueInfo = new TupleTypeInfo(
BasicTypeInfo.STRING_TYPE_INFO,
createTypeInformation[(Double, String, Long)])
assertEquals(initialData.set.getType, keyExtractor.getOperatorInfo.getInputType)
assertEquals(initialData.javaSet.getType, keyExtractor.getOperatorInfo.getInputType)
assertEquals(keyValueInfo, keyExtractor.getOperatorInfo.getOutputType)
assertEquals(keyValueInfo, reducer.getOperatorInfo.getInputType)
assertEquals(keyValueInfo, reducer.getOperatorInfo.getOutputType)
assertEquals(keyValueInfo, keyProjector.getOperatorInfo.getInputType)
assertEquals(initialData.set.getType, keyProjector.getOperatorInfo.getOutputType)
assertEquals(initialData.javaSet.getType, keyProjector.getOperatorInfo.getOutputType)
assertEquals(
classOf[KeyExtractingMapper[_, _]],
keyExtractor.getUserCodeWrapper.getUserCodeClass)
......