Commit f97f9ba9 authored by Fabian Hueske

[FLINK-1060] Added methods to DataSet to explicitly hash-partition or rebalance the input of Map-based operators.

This closes #108
Parent 68801f2d
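
For orientation, here is a minimal usage sketch of the API this commit adds. partitionByHash(...) and rebalance() are the new DataSet methods; the environment setup, sample data, and output paths are illustrative only.

import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.util.Collector;

public class PartitionExample {

	public static void main(String[] args) throws Exception {
		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		DataSet<Tuple2<Integer, String>> data = env.fromElements(
				new Tuple2<Integer, String>(1, "a"),
				new Tuple2<Integer, String>(2, "b"),
				new Tuple2<Integer, String>(1, "c"));

		// New in this commit: hash-partition on tuple field 0 before a Map-based operator.
		data.partitionByHash(0)
				.mapPartition(new ValueExtractor())
				.writeAsText("/tmp/partitioned"); // illustrative output path

		// Also new: force an even, round-robin redistribution of the data.
		data.rebalance()
				.mapPartition(new ValueExtractor())
				.writeAsText("/tmp/rebalanced"); // illustrative output path

		env.execute();
	}

	// Emits the String field of every tuple in its partition.
	public static class ValueExtractor implements MapPartitionFunction<Tuple2<Integer, String>, String> {
		@Override
		public void mapPartition(Iterable<Tuple2<Integer, String>> values, Collector<String> out) {
			for (Tuple2<Integer, String> t : values) {
				out.collect(t.f1);
			}
		}
	}
}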
......@@ -46,6 +46,7 @@ import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
import org.apache.flink.api.common.operators.base.JoinOperatorBase;
import org.apache.flink.api.common.operators.base.MapOperatorBase;
import org.apache.flink.api.common.operators.base.MapPartitionOperatorBase;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase;
import org.apache.flink.api.common.operators.base.ReduceOperatorBase;
import org.apache.flink.api.common.operators.base.BulkIterationBase.PartialSolutionPlaceHolder;
import org.apache.flink.api.common.operators.base.DeltaIterationBase.SolutionSetPlaceHolder;
......@@ -69,6 +70,7 @@ import org.apache.flink.compiler.dag.MapPartitionNode;
import org.apache.flink.compiler.dag.MatchNode;
import org.apache.flink.compiler.dag.OptimizerNode;
import org.apache.flink.compiler.dag.PactConnection;
import org.apache.flink.compiler.dag.PartitionNode;
import org.apache.flink.compiler.dag.ReduceNode;
import org.apache.flink.compiler.dag.SinkJoiner;
import org.apache.flink.compiler.dag.SolutionSetNode;
......@@ -708,6 +710,9 @@ public class PactCompiler {
else if (c instanceof Union){
n = new BinaryUnionNode((Union<?>) c);
}
else if (c instanceof PartitionOperatorBase) {
n = new PartitionNode((PartitionOperatorBase<?>) c);
}
else if (c instanceof PartialSolutionPlaceHolder) {
if (this.parent == null) {
throw new InvalidProgramException("It is currently not supported to create data sinks inside iterations.");
......
......@@ -113,6 +113,9 @@ public abstract class CostEstimator {
case BROADCAST:
addBroadcastCost(channel, channel.getReplicationFactor(), costs);
break;
case PARTITION_FORCED_REBALANCE:
addRandomPartitioningCost(channel, costs);
break;
default:
throw new CompilerException("Unknown shipping strategy for input: " + channel.getShipStrategy());
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.compiler.dag;
import java.util.Collections;
import java.util.List;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod;
import org.apache.flink.api.common.operators.util.FieldSet;
import org.apache.flink.compiler.DataStatistics;
import org.apache.flink.compiler.dataproperties.GlobalProperties;
import org.apache.flink.compiler.dataproperties.LocalProperties;
import org.apache.flink.compiler.dataproperties.RequestedGlobalProperties;
import org.apache.flink.compiler.dataproperties.RequestedLocalProperties;
import org.apache.flink.compiler.operators.OperatorDescriptorSingle;
import org.apache.flink.compiler.plan.Channel;
import org.apache.flink.compiler.plan.SingleInputPlanNode;
import org.apache.flink.runtime.operators.DriverStrategy;
/**
* The optimizer's internal representation of a <i>Partition</i> operator node.
*/
public class PartitionNode extends SingleInputNode {
public PartitionNode(PartitionOperatorBase<?> operator) {
super(operator);
}
@Override
public PartitionOperatorBase<?> getPactContract() {
return (PartitionOperatorBase<?>) super.getPactContract();
}
@Override
public String getName() {
return "Partition";
}
@Override
protected List<OperatorDescriptorSingle> getPossibleProperties() {
return Collections.<OperatorDescriptorSingle>singletonList(new PartitionDescriptor(this.getPactContract().getPartitionMethod(), this.keys));
}
@Override
protected void computeOperatorSpecificDefaultEstimates(DataStatistics statistics) {
// partitioning does not change the number of records
this.estimatedNumRecords = getPredecessorNode().getEstimatedNumRecords();
}
@Override
public boolean isFieldConstant(int input, int fieldNumber) {
// Partition does not change any data
return true;
}
public static class PartitionDescriptor extends OperatorDescriptorSingle {
private final PartitionMethod pMethod;
private final FieldSet pKeys;
public PartitionDescriptor(PartitionMethod pMethod, FieldSet pKeys) {
this.pMethod = pMethod;
this.pKeys = pKeys;
}
@Override
public DriverStrategy getStrategy() {
return DriverStrategy.UNARY_NO_OP;
}
@Override
public SingleInputPlanNode instantiate(Channel in, SingleInputNode node) {
return new SingleInputPlanNode(node, "Partition", in, DriverStrategy.UNARY_NO_OP);
}
@Override
protected List<RequestedGlobalProperties> createPossibleGlobalProperties() {
RequestedGlobalProperties rgps = new RequestedGlobalProperties();
switch (this.pMethod) {
case HASH:
rgps.setHashPartitioned(pKeys.toFieldList());
break;
case REBALANCE:
rgps.setForceRebalancing();
break;
case RANGE:
throw new UnsupportedOperationException("Not yet supported");
default:
throw new IllegalArgumentException("Invalid partition method");
}
return Collections.singletonList(rgps);
}
@Override
protected List<RequestedLocalProperties> createPossibleLocalProperties() {
// partitioning does not require any local property.
return Collections.singletonList(new RequestedLocalProperties());
}
@Override
public GlobalProperties computeGlobalProperties(GlobalProperties gProps) {
// the partition node is a no-op operation, so all global properties are preserved.
return gProps;
}
@Override
public LocalProperties computeLocalProperties(LocalProperties lProps) {
// the partition node is a no-op operation, so all local properties are preserved.
return lProps;
}
}
}
......@@ -97,6 +97,12 @@ public class GlobalProperties implements Cloneable {
this.ordering = null;
}
public void setForcedRebalanced() {
this.partitioning = PartitioningProperty.FORCED_REBALANCED;
this.partitioningFields = null;
this.ordering = null;
}
public void addUniqueFieldCombination(FieldSet fields) {
if (this.uniqueFieldCombinations == null) {
this.uniqueFieldCombinations = new HashSet<FieldSet>();
......
......@@ -47,8 +47,13 @@ public enum PartitioningProperty {
/**
* Constant indicating full replication of the data to each parallel instance.
*/
FULL_REPLICATION;
FULL_REPLICATION,
/**
* Constant indicating a forced even rebalancing.
*/
FORCED_REBALANCED;
/**
* Checks whether this property represents in fact a partitioning. That is,
* whether this property is neither <tt>FULL_REPLICATION</tt> nor <tt>FORCED_REBALANCED</tt>.
......@@ -57,7 +62,7 @@ public enum PartitioningProperty {
* false otherwise.
*/
public boolean isPartitioned() {
return this != FULL_REPLICATION;
return this != FULL_REPLICATION && this != FORCED_REBALANCED;
}
/**
......
......@@ -106,6 +106,12 @@ public final class RequestedGlobalProperties implements Cloneable {
this.ordering = null;
}
public void setForceRebalancing() {
this.partitioning = PartitioningProperty.FORCED_REBALANCED;
this.partitioningFields = null;
this.ordering = null;
}
/**
* Gets the partitioning property.
*
......@@ -211,6 +217,8 @@ public final class RequestedGlobalProperties implements Cloneable {
} else if (this.partitioning == PartitioningProperty.RANGE_PARTITIONED) {
return props.getPartitioning() == PartitioningProperty.RANGE_PARTITIONED &&
props.matchesOrderedPartitioning(this.ordering);
} else if (this.partitioning == PartitioningProperty.FORCED_REBALANCED) {
return props.getPartitioning() == PartitioningProperty.FORCED_REBALANCED;
} else {
throw new CompilerException("Bug in properties matching logic.");
}
......@@ -253,6 +261,9 @@ public final class RequestedGlobalProperties implements Cloneable {
channel.setDataDistribution(this.dataDistribution);
}
break;
case FORCED_REBALANCED:
channel.setShipStrategy(ShipStrategyType.PARTITION_FORCED_REBALANCE);
break;
default:
throw new CompilerException("Unsupported partitioning property: " + this.partitioning);
}
......
......@@ -67,13 +67,37 @@ public abstract class OperatorDescriptorSingle implements AbstractOperatorDescri
return this.localProps;
}
/**
* Returns a list of global properties that are required by this operator descriptor.
*
* @return A list of global properties that are required by this operator descriptor.
*/
protected abstract List<RequestedGlobalProperties> createPossibleGlobalProperties();
/**
* Returns a list of local properties that are required by this operator descriptor.
*
* @return A list of local properties that are required by this operator descriptor.
*/
protected abstract List<RequestedLocalProperties> createPossibleLocalProperties();
public abstract SingleInputPlanNode instantiate(Channel in, SingleInputNode node);
/**
* Returns the global properties that are present after the operator has been
* applied to the provided global properties.
*
* @param in The global properties on which the operator is applied.
* @return The global properties which are valid after the operator has been applied.
*/
public abstract GlobalProperties computeGlobalProperties(GlobalProperties in);
/**
* Returns the local properties that are present after the operator has been
* applied to the provided local properties.
*
* @param in The local properties on which the operator is applied.
* @return The local properties which are valid after the operator has been applied.
*/
public abstract LocalProperties computeLocalProperties(LocalProperties in);
}
......@@ -378,6 +378,9 @@ public class Channel implements EstimateProvider, Cloneable, DumpableConnection<
case PARTITION_RANDOM:
this.globalProps.reset();
break;
case PARTITION_FORCED_REBALANCE:
this.globalProps.setForcedRebalanced();
break;
case NONE:
throw new CompilerException("Cannot produce GlobalProperties before ship strategy is set.");
}
......@@ -410,6 +413,7 @@ public class Channel implements EstimateProvider, Cloneable, DumpableConnection<
case PARTITION_HASH:
case PARTITION_RANGE:
case PARTITION_RANDOM:
case PARTITION_FORCED_REBALANCE:
this.localProps = new LocalProperties();
break;
case FORWARD:
......@@ -417,6 +421,8 @@ public class Channel implements EstimateProvider, Cloneable, DumpableConnection<
break;
case NONE:
throw new CompilerException("ShipStrategy has not yet been set.");
default:
throw new CompilerException("Unknown ShipStrategy.");
}
}
......
......@@ -327,6 +327,9 @@ public class PlanJSONDumpGenerator {
case PARTITION_RANDOM:
shipStrategy = "Redistribute";
break;
case PARTITION_FORCED_REBALANCE:
shipStrategy = "Rebalance";
break;
default:
throw new CompilerException("Unknown ship strategy '" + conn.getShipStrategy().name()
+ "' in JSON generator.");
......
......@@ -1023,6 +1023,7 @@ public class NepheleJobGraphGenerator implements Visitor<PlanNode> {
case BROADCAST:
case PARTITION_HASH:
case PARTITION_RANGE:
case PARTITION_FORCED_REBALANCE:
distributionPattern = DistributionPattern.BIPARTITE;
break;
default:
......
......@@ -18,6 +18,7 @@
package org.apache.flink.compiler.util;
import org.apache.flink.api.common.functions.util.NoOpFunction;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.operators.DualInputOperator;
import org.apache.flink.api.common.operators.RecordOperator;
......
......@@ -18,6 +18,7 @@
package org.apache.flink.compiler.util;
import org.apache.flink.api.common.functions.util.NoOpFunction;
import org.apache.flink.api.common.operators.RecordOperator;
import org.apache.flink.api.common.operators.SingleInputOperator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
......
......@@ -16,7 +16,7 @@
* limitations under the License.
*/
package org.apache.flink.compiler.util;
package org.apache.flink.api.common.functions.util;
import org.apache.flink.api.common.functions.AbstractRichFunction;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.operators.base;
import org.apache.flink.api.common.functions.util.NoOpFunction;
import org.apache.flink.api.common.operators.SingleInputOperator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
/**
* A single-input operator that redistributes its input data set without modifying the records.
* @param <IN> The input and result type.
*/
public class PartitionOperatorBase<IN> extends SingleInputOperator<IN, IN, NoOpFunction> {
private final PartitionMethod partitionMethod;
public PartitionOperatorBase(UnaryOperatorInformation<IN, IN> operatorInfo, PartitionMethod pMethod, int[] keys, String name) {
super(new UserCodeObjectWrapper<NoOpFunction>(new NoOpFunction()), operatorInfo, keys, name);
this.partitionMethod = pMethod;
}
public PartitionOperatorBase(UnaryOperatorInformation<IN, IN> operatorInfo, PartitionMethod pMethod, String name) {
super(new UserCodeObjectWrapper<NoOpFunction>(new NoOpFunction()), operatorInfo, name);
this.partitionMethod = pMethod;
}
public PartitionMethod getPartitionMethod() {
return this.partitionMethod;
}
public static enum PartitionMethod {
REBALANCE,
HASH,
RANGE;
}
}
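
For illustration, a sketch of how a translation layer can wire this operator into a plan (compare PartitionedDataSet.translateToDataFlow further below); the helper class, method name, and parameters are illustrative, not part of this commit.

import org.apache.flink.api.common.operators.Operator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod;
import org.apache.flink.api.common.typeinfo.TypeInformation;

class PartitionWiringSketch {

	// Inserts a hash-partitioning no-op between 'input' and its consumer;
	// 'type', 'keys', and 'dop' are supplied by the caller.
	static <T> PartitionOperatorBase<T> hashPartition(Operator<T> input, TypeInformation<T> type, int[] keys, int dop) {
		UnaryOperatorInformation<T, T> info = new UnaryOperatorInformation<T, T>(type, type);
		PartitionOperatorBase<T> partition = new PartitionOperatorBase<T>(info, PartitionMethod.HASH, keys, "Partition");
		partition.setInput(input);             // upstream data flow operator
		partition.setDegreeOfParallelism(dop); // parallelism of the partitioning
		return partition;
	}
}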
......@@ -30,6 +30,7 @@ import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.io.FileOutputFormat;
import org.apache.flink.api.common.io.OutputFormat;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.functions.FormattingMapper;
import org.apache.flink.api.java.functions.KeySelector;
......@@ -56,6 +57,7 @@ import org.apache.flink.api.java.operators.JoinOperator.JoinOperatorSets;
import org.apache.flink.api.java.operators.Keys;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.operators.PartitionedDataSet;
import org.apache.flink.api.java.operators.ProjectOperator.Projection;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.ReduceOperator;
......@@ -845,6 +847,49 @@ public abstract class DataSet<T> {
return new UnionOperator<T>(this, other);
}
// --------------------------------------------------------------------------------------------
// Partitioning
// --------------------------------------------------------------------------------------------
/**
* Hash-partitions a DataSet on the specified key fields.
* <p>
* <b>Important:</b> This operation shuffles the whole DataSet over the network and can take a significant amount of time.
*
* @param fields The field indexes on which the DataSet is hash-partitioned.
* @return The partitioned DataSet.
*/
public PartitionedDataSet<T> partitionByHash(int... fields) {
return new PartitionedDataSet<T>(this, PartitionMethod.HASH, new Keys.FieldPositionKeys<T>(fields, getType(), false));
}
/**
* Hash-partitions a DataSet using the specified KeySelector.
* <p>
* <b>Important:</b> This operation shuffles the whole DataSet over the network and can take a significant amount of time.
*
* @param keyExtractor The KeySelector with which the DataSet is hash-partitioned.
* @return The partitioned DataSet.
*
* @see KeySelector
*/
public <K extends Comparable<K>> PartitionedDataSet<T> partitionByHash(KeySelector<T, K> keyExtractor) {
final TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, type);
return new PartitionedDataSet<T>(this, PartitionMethod.HASH, new Keys.SelectorFunctionKeys<T, K>(keyExtractor, this.getType(), keyType));
}
/**
* Enforces a rebalancing of the DataSet, i.e., the DataSet is evenly distributed over all parallel instances of the
* following task. This can help to improve performance in case of heavy data skew and compute-intensive operations.
* <p>
* <b>Important:</b> This operation shuffles the whole DataSet over the network and can take a significant amount of time.
*
* @return The rebalanced DataSet.
*/
public PartitionedDataSet<T> rebalance() {
return new PartitionedDataSet<T>(this, PartitionMethod.REBALANCE);
}
// --------------------------------------------------------------------------------------------
// Top-K
// --------------------------------------------------------------------------------------------
......
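A short usage sketch for the KeySelector variant documented above, assuming an ExecutionEnvironment env as in any Flink program; the sample data and the downstream MapPartitionFunction are illustrative:

DataSet<String> words = env.fromElements("apache", "flink", "rocks");

// Hash-partition by a derived key; the key type must be Comparable.
words.partitionByHash(new KeySelector<String, Integer>() {
		@Override
		public Integer getKey(String value) {
			return value.length();
		}
	})
	.mapPartition(new MapPartitionFunction<String, String>() {
		@Override
		public void mapPartition(Iterable<String> values, Collector<String> out) {
			for (String s : values) {
				out.collect(s);
			}
		}
	})
	.writeAsText("/tmp/by-length"); // illustrative output path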
......@@ -22,7 +22,6 @@ import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.operators.Operator;
import org.apache.flink.api.java.operators.translation.PlanFilterOperator;
import org.apache.flink.api.java.DataSet;
/**
......@@ -35,7 +34,8 @@ public class FilterOperator<T> extends SingleInputUdfOperator<T, T, FilterOperat
protected final FilterFunction<T> function;
protected PartitionedDataSet<T> partitionedDataSet;
public FilterOperator(DataSet<T> input, FilterFunction<T> function) {
super(input, input.getType());
......@@ -43,9 +43,19 @@ public class FilterOperator<T> extends SingleInputUdfOperator<T, T, FilterOperat
extractSemanticAnnotationsFromUdf(function.getClass());
}
public FilterOperator(PartitionedDataSet<T> input, FilterFunction<T> function) {
this(input.getDataSet(), function);
this.partitionedDataSet = input;
}
@Override
protected org.apache.flink.api.common.operators.base.FilterOperatorBase<T, FlatMapFunction<T,T>> translateToDataFlow(Operator<T> input) {
// inject partition operator if necessary
if(this.partitionedDataSet != null) {
input = this.partitionedDataSet.translateToDataFlow(input, this.getParallelism());
}
String name = getName() != null ? getName() : function.getClass().getName();
// create operator
PlanFilterOperator<T> po = new PlanFilterOperator<T>(function, name, getInputType());
......
......@@ -36,6 +36,7 @@ public class FlatMapOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT, Fl
protected final FlatMapFunction<IN, OUT> function;
protected PartitionedDataSet<IN> partitionedDataSet;
public FlatMapOperator(DataSet<IN> input, TypeInformation<OUT> resultType, FlatMapFunction<IN, OUT> function) {
super(input, resultType);
......@@ -43,10 +44,20 @@ public class FlatMapOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT, Fl
this.function = function;
extractSemanticAnnotationsFromUdf(function.getClass());
}
public FlatMapOperator(PartitionedDataSet<IN> input, TypeInformation<OUT> resultType, FlatMapFunction<IN, OUT> function) {
this(input.getDataSet(), resultType, function);
this.partitionedDataSet = input;
}
@Override
protected org.apache.flink.api.common.operators.base.FlatMapOperatorBase<IN, OUT, FlatMapFunction<IN,OUT>> translateToDataFlow(Operator<IN> input) {
// inject partition operator if necessary
if(this.partitionedDataSet != null) {
input = this.partitionedDataSet.translateToDataFlow(input, this.getParallelism());
}
String name = getName() != null ? getName() : function.getClass().getName();
// create operator
FlatMapOperatorBase<IN, OUT, FlatMapFunction<IN, OUT>> po = new FlatMapOperatorBase<IN, OUT, FlatMapFunction<IN, OUT>>(function, new UnaryOperatorInformation<IN, OUT>(getInputType(), getResultType()), name);
......
......@@ -38,6 +38,8 @@ public class MapOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT, MapOpe
protected final MapFunction<IN, OUT> function;
protected PartitionedDataSet<IN> partitionedDataSet;
public MapOperator(DataSet<IN> input, TypeInformation<OUT> resultType, MapFunction<IN, OUT> function) {
......@@ -46,10 +48,20 @@ public class MapOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT, MapOpe
this.function = function;
extractSemanticAnnotationsFromUdf(function.getClass());
}
public MapOperator(PartitionedDataSet<IN> input, TypeInformation<OUT> resultType, MapFunction<IN, OUT> function) {
this(input.getDataSet(), resultType, function);
this.partitionedDataSet = input;
}
@Override
protected org.apache.flink.api.common.operators.base.MapOperatorBase<IN, OUT, MapFunction<IN, OUT>> translateToDataFlow(Operator<IN> input) {
// inject partition operator if necessary
if(this.partitionedDataSet != null) {
input = this.partitionedDataSet.translateToDataFlow(input, this.getParallelism());
}
String name = getName() != null ? getName() : function.getClass().getName();
// create operator
MapOperatorBase<IN, OUT, MapFunction<IN, OUT>> po = new MapOperatorBase<IN, OUT, MapFunction<IN, OUT>>(function, new UnaryOperatorInformation<IN, OUT>(getInputType(), getResultType()), name);
......
......@@ -38,6 +38,7 @@ public class MapPartitionOperator<IN, OUT> extends SingleInputUdfOperator<IN, OU
protected final MapPartitionFunction<IN, OUT> function;
protected PartitionedDataSet<IN> partitionedDataSet;
public MapPartitionOperator(DataSet<IN> input, TypeInformation<OUT> resultType, MapPartitionFunction<IN, OUT> function) {
super(input, resultType);
......@@ -45,10 +46,20 @@ public class MapPartitionOperator<IN, OUT> extends SingleInputUdfOperator<IN, OU
this.function = function;
extractSemanticAnnotationsFromUdf(function.getClass());
}
public MapPartitionOperator(PartitionedDataSet<IN> input, TypeInformation<OUT> resultType, MapPartitionFunction<IN, OUT> function) {
this(input.getDataSet(), resultType, function);
this.partitionedDataSet = input;
}
@Override
protected MapPartitionOperatorBase<IN, OUT, MapPartitionFunction<IN, OUT>> translateToDataFlow(Operator<IN> input) {
// inject partition operator if necessary
if(this.partitionedDataSet != null) {
input = this.partitionedDataSet.translateToDataFlow(input, this.getParallelism());
}
String name = getName() != null ? getName() : function.getClass().getName();
// create operator
MapPartitionOperatorBase<IN, OUT, MapPartitionFunction<IN, OUT>> po = new MapPartitionOperatorBase<IN, OUT, MapPartitionFunction<IN, OUT>>(function, new UnaryOperatorInformation<IN, OUT>(getInputType(), getResultType()), name);
......
......@@ -67,6 +67,7 @@ public class OperatorTranslation {
private <T> Operator<T> translate(DataSet<T> dataSet) {
// check if we have already translated that data set (operation or source)
Operator<?> previous = (Operator<?>) this.translated.get(dataSet);
if (previous != null) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.operators;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.operators.Operator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.base.MapOperatorBase;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.UnsupportedLambdaExpressionException;
import org.apache.flink.api.java.operators.translation.KeyExtractingMapper;
import org.apache.flink.api.java.operators.translation.KeyRemovingMapper;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
public class PartitionedDataSet<IN> {
private final DataSet<IN> dataSet;
private final Keys<IN> pKeys;
private final PartitionMethod pMethod;
public PartitionedDataSet(DataSet<IN> input, PartitionMethod pMethod, Keys<IN> pKeys) {
this.dataSet = input;
if(pMethod == PartitionMethod.HASH && pKeys == null) {
throw new IllegalArgumentException("Hash Partitioning requires keys");
} else if(pMethod == PartitionMethod.RANGE) {
throw new UnsupportedOperationException("Range Partitioning not yet supported");
}
if(pKeys instanceof Keys.FieldPositionKeys<?> && !input.getType().isTupleType()) {
throw new IllegalArgumentException("Hash Partitioning with key fields only possible on Tuple DataSets");
}
this.pMethod = pMethod;
this.pKeys = pKeys;
}
public PartitionedDataSet(DataSet<IN> input, PartitionMethod pMethod) {
this(input, pMethod, null);
}
public DataSet<IN> getDataSet() {
return this.dataSet;
}
/**
* Applies a Map transformation on a {@link DataSet}.<br/>
* The transformation calls a {@link org.apache.flink.api.java.functions.RichMapFunction} for each element of the DataSet.
* Each MapFunction call returns exactly one element.
*
* @param mapper The MapFunction that is called for each element of the DataSet.
* @return A MapOperator that represents the transformed DataSet.
*
* @see org.apache.flink.api.java.functions.RichMapFunction
* @see MapOperator
* @see DataSet
*/
public <R> MapOperator<IN, R> map(MapFunction<IN, R> mapper) {
if (mapper == null) {
throw new NullPointerException("Map function must not be null.");
}
if (FunctionUtils.isLambdaFunction(mapper)) {
throw new UnsupportedLambdaExpressionException();
}
final TypeInformation<R> resultType = TypeExtractor.getMapReturnTypes(mapper, dataSet.getType());
return new MapOperator<IN, R>(this, resultType, mapper);
}
/**
* Applies a Map-style operation to the entire partition of the data.
* The function is called once per parallel partition of the data,
* and the entire partition is available through the given Iterable.
* The number of elements that each instance of the MapPartition function
* sees is non-deterministic and depends on the degree of parallelism of the operation.
*
* This function is intended for operations that cannot transform individual elements
* and that require no grouping of elements. To transform individual elements,
* the use of {@code map()} and {@code flatMap()} is preferable.
*
* @param mapPartition The MapPartitionFunction that is called for the full DataSet.
* @return A MapPartitionOperator that represents the transformed DataSet.
*
* @see MapPartitionFunction
* @see MapPartitionOperator
* @see DataSet
*/
public <R> MapPartitionOperator<IN, R> mapPartition(MapPartitionFunction<IN, R> mapPartition) {
if (mapPartition == null) {
throw new NullPointerException("MapPartition function must not be null.");
}
final TypeInformation<R> resultType = TypeExtractor.getMapPartitionReturnTypes(mapPartition, dataSet.getType());
return new MapPartitionOperator<IN, R>(this, resultType, mapPartition);
}
/**
* Applies a FlatMap transformation on a {@link DataSet}.<br/>
* The transformation calls a {@link org.apache.flink.api.java.functions.RichFlatMapFunction} for each element of the DataSet.
* Each FlatMapFunction call can return any number of elements including none.
*
* @param flatMapper The FlatMapFunction that is called for each element of the DataSet.
* @return A FlatMapOperator that represents the transformed DataSet.
*
* @see org.apache.flink.api.java.functions.RichFlatMapFunction
* @see FlatMapOperator
* @see DataSet
*/
public <R> FlatMapOperator<IN, R> flatMap(FlatMapFunction<IN, R> flatMapper) {
if (flatMapper == null) {
throw new NullPointerException("FlatMap function must not be null.");
}
if (FunctionUtils.isLambdaFunction(flatMapper)) {
throw new UnsupportedLambdaExpressionException();
}
TypeInformation<R> resultType = TypeExtractor.getFlatMapReturnTypes(flatMapper, dataSet.getType());
return new FlatMapOperator<IN, R>(this, resultType, flatMapper);
}
/**
* Applies a Filter transformation on a {@link DataSet}.<br/>
* The transformation calls a {@link org.apache.flink.api.java.functions.RichFilterFunction} for each element of the DataSet
* and retains only those elements for which the function returns true. Elements for
* which the function returns false are filtered out.
*
* @param filter The FilterFunction that is called for each element of the DataSet.
* @return A FilterOperator that represents the filtered DataSet.
*
* @see org.apache.flink.api.java.functions.RichFilterFunction
* @see FilterOperator
* @see DataSet
*/
public FilterOperator<IN> filter(FilterFunction<IN> filter) {
if (filter == null) {
throw new NullPointerException("Filter function must not be null.");
}
return new FilterOperator<IN>(this, filter);
}
/*
* Translation of partitioning
*/
protected org.apache.flink.api.common.operators.SingleInputOperator<?, IN, ?> translateToDataFlow(Operator<IN> input, int partitionDop) {
String name = "Partition";
// distinguish between partition types
if (pMethod == PartitionMethod.REBALANCE) {
UnaryOperatorInformation<IN, IN> operatorInfo = new UnaryOperatorInformation<IN, IN>(dataSet.getType(), dataSet.getType());
PartitionOperatorBase<IN> noop = new PartitionOperatorBase<IN>(operatorInfo, pMethod, name);
// set input
noop.setInput(input);
// set DOP
noop.setDegreeOfParallelism(partitionDop);
return noop;
}
else if (pMethod == PartitionMethod.HASH) {
if (pKeys instanceof Keys.FieldPositionKeys) {
int[] logicalKeyPositions = pKeys.computeLogicalKeyPositions();
UnaryOperatorInformation<IN, IN> operatorInfo = new UnaryOperatorInformation<IN, IN>(dataSet.getType(), dataSet.getType());
PartitionOperatorBase<IN> noop = new PartitionOperatorBase<IN>(operatorInfo, pMethod, logicalKeyPositions, name);
// set input
noop.setInput(input);
// set DOP
noop.setDegreeOfParallelism(partitionDop);
return noop;
} else if (pKeys instanceof Keys.SelectorFunctionKeys) {
@SuppressWarnings("unchecked")
Keys.SelectorFunctionKeys<IN, ?> selectorKeys = (Keys.SelectorFunctionKeys<IN, ?>) pKeys;
MapOperatorBase<?, IN, ?> po = translateSelectorFunctionReducer(selectorKeys, pMethod, dataSet.getType(), name, input, partitionDop);
return po;
}
else {
throw new UnsupportedOperationException("Unrecognized key type.");
}
}
else if (pMethod == PartitionMethod.RANGE) {
	throw new UnsupportedOperationException("Range partitioning not yet supported");
}
else {
	throw new IllegalArgumentException("Unknown partition method: " + pMethod);
}
}
// --------------------------------------------------------------------------------------------
private static <T, K> MapOperatorBase<Tuple2<K, T>, T, ?> translateSelectorFunctionReducer(Keys.SelectorFunctionKeys<T, ?> rawKeys,
PartitionMethod pMethod, TypeInformation<T> inputType, String name, Operator<T> input, int partitionDop)
{
@SuppressWarnings("unchecked")
final Keys.SelectorFunctionKeys<T, K> keys = (Keys.SelectorFunctionKeys<T, K>) rawKeys;
TypeInformation<Tuple2<K, T>> typeInfoWithKey = new TupleTypeInfo<Tuple2<K, T>>(keys.getKeyType(), inputType);
UnaryOperatorInformation<Tuple2<K, T>, Tuple2<K, T>> operatorInfo = new UnaryOperatorInformation<Tuple2<K, T>, Tuple2<K, T>>(typeInfoWithKey, typeInfoWithKey);
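// Plan shape produced by this translation: input -> Map "Key Extractor" (T -> Tuple2<key, value>)
// -> Partition (hash on field 0, i.e. the extracted key) -> Map "Key Remover" (Tuple2<key, value> -> T).
// Wrapping the key into a Tuple2 lets the partitioner operate on a plain field position
// even when the key is computed by a selector function.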
KeyExtractingMapper<T, K> extractor = new KeyExtractingMapper<T, K>(keys.getKeyExtractor());
MapOperatorBase<T, Tuple2<K, T>, MapFunction<T, Tuple2<K, T>>> keyExtractingMap = new MapOperatorBase<T, Tuple2<K, T>, MapFunction<T, Tuple2<K, T>>>(extractor, new UnaryOperatorInformation<T, Tuple2<K, T>>(inputType, typeInfoWithKey), "Key Extractor");
PartitionOperatorBase<Tuple2<K, T>> noop = new PartitionOperatorBase<Tuple2<K, T>>(operatorInfo, pMethod, new int[]{0}, name);
MapOperatorBase<Tuple2<K, T>, T, MapFunction<Tuple2<K, T>, T>> keyRemovingMap = new MapOperatorBase<Tuple2<K, T>, T, MapFunction<Tuple2<K, T>, T>>(new KeyRemovingMapper<T, K>(), new UnaryOperatorInformation<Tuple2<K, T>, T>(typeInfoWithKey, inputType), "Key Remover");
keyExtractingMap.setInput(input);
noop.setInput(keyExtractingMap);
keyRemovingMap.setInput(noop);
// set dop
keyExtractingMap.setDegreeOfParallelism(input.getDegreeOfParallelism());
noop.setDegreeOfParallelism(partitionDop);
keyRemovingMap.setDegreeOfParallelism(partitionDop);
return keyRemovingMap;
}
}
......@@ -86,6 +86,7 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
case PARTITION_HASH:
case PARTITION_RANGE:
case PARTITION_RANDOM:
case PARTITION_FORCED_REBALANCE:
case BROADCAST:
break;
default:
......@@ -106,6 +107,7 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
switch (strategy) {
case FORWARD:
case PARTITION_RANDOM:
case PARTITION_FORCED_REBALANCE:
return robin(numberOfChannels);
case PARTITION_HASH:
return hashPartitionDefault(record.getInstance(), numberOfChannels);
......
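For forced rebalancing, the emitter reuses the same round-robin channel selection as PARTITION_RANDOM. A minimal sketch of such a selector, assuming a per-emitter counter (the field name is an assumption; OutputEmitter keeps equivalent state internally):

// Round-robin channel selection: cycle through the output channels, one record at a time.
private int nextChannelToSendTo = 0;

private int[] robin(int numberOfChannels) {
	this.nextChannelToSendTo = (this.nextChannelToSendTo + 1) % numberOfChannels;
	return new int[] { this.nextChannelToSendTo };
}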
......@@ -50,6 +50,11 @@ public enum ShipStrategyType {
*/
PARTITION_RANGE(true, true),
/**
* Partitioning the data evenly across all parallel instances (forced rebalancing).
*/
PARTITION_FORCED_REBALANCE(true, false),
/**
* Replicating the data set to all instances.
*/
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.javaApiOperators;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.test.javaApiOperators.util.CollectionDataSets;
import org.apache.flink.test.util.JavaProgramTestBase;
import org.apache.flink.util.Collector;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
@RunWith(Parameterized.class)
public class PartitionITCase extends JavaProgramTestBase {
private static int NUM_PROGRAMS = 4;
private int curProgId = config.getInteger("ProgramId", -1);
private String resultPath;
private String expectedResult;
public PartitionITCase(Configuration config) {
super(config);
}
@Override
protected void preSubmit() throws Exception {
resultPath = getTempDirPath("result");
}
@Override
protected void testProgram() throws Exception {
expectedResult = PartitionProgs.runProgram(curProgId, resultPath);
}
@Override
protected void postSubmit() throws Exception {
compareResultsByLinesInMemory(expectedResult, resultPath);
}
@Parameters
public static Collection<Object[]> getConfigurations() throws FileNotFoundException, IOException {
LinkedList<Configuration> tConfigs = new LinkedList<Configuration>();
for(int i=1; i <= NUM_PROGRAMS; i++) {
Configuration config = new Configuration();
config.setInteger("ProgramId", i);
tConfigs.add(config);
}
return toParameterList(tConfigs);
}
private static class PartitionProgs {
public static String runProgram(int progId, String resultPath) throws Exception {
switch(progId) {
case 1: {
/*
* Test hash partition by key field
*/
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
DataSet<Long> uniqLongs = ds
.partitionByHash(1)
.mapPartition(new UniqueLongMapper());
uniqLongs.writeAsText(resultPath);
env.execute();
// return expected result
return "1\n" +
"2\n" +
"3\n" +
"4\n" +
"5\n" +
"6\n";
}
case 2: {
/*
* Test hash partition by key selector
*/
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
DataSet<Long> uniqLongs = ds
.partitionByHash(new KeySelector<Tuple3<Integer,Long,String>, Long>() {
private static final long serialVersionUID = 1L;
@Override
public Long getKey(Tuple3<Integer, Long, String> value) throws Exception {
return value.f1;
}
})
.mapPartition(new UniqueLongMapper());
uniqLongs.writeAsText(resultPath);
env.execute();
// return expected result
return "1\n" +
"2\n" +
"3\n" +
"4\n" +
"5\n" +
"6\n";
}
case 3: {
/*
* Test forced rebalancing
*/
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
// generate some numbers in parallel
DataSet<Long> ds = env.generateSequence(1,3000);
DataSet<Tuple2<Integer, Integer>> uniqLongs = ds
// introduce some partition skew by filtering
.filter(new FilterFunction<Long>() {
private static final long serialVersionUID = 1L;
@Override
public boolean filter(Long value) throws Exception {
return value > 780;
}
})
// rebalance
.rebalance()
// count values in each partition
.map(new PartitionIndexMapper())
.groupBy(0)
.reduce(new ReduceFunction<Tuple2<Integer, Integer>>() {
private static final long serialVersionUID = 1L;
public Tuple2<Integer, Integer> reduce(Tuple2<Integer, Integer> v1, Tuple2<Integer, Integer> v2) {
return new Tuple2<Integer, Integer>(v1.f0, v1.f1+v2.f1);
}
})
// round counts to mitigate runtime scheduling effects (lazy split assignment)
.map(new MapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>(){
private static final long serialVersionUID = 1L;
@Override
public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> value) throws Exception {
value.f1 = (value.f1 / 10);
return value;
}
});
uniqLongs.writeAsText(resultPath);
env.execute();
// return expected result
return "(0,55)\n" +
"(1,55)\n" +
"(2,55)\n" +
"(3,55)\n";
}
case 4: {
/*
* Test hash partition by key field and different DOP
*/
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setDegreeOfParallelism(3);
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
DataSet<Long> uniqLongs = ds
.partitionByHash(1)
.mapPartition(new UniqueLongMapper()).setParallelism(4);
uniqLongs.writeAsText(resultPath);
env.execute();
// return expected result
return "1\n" +
"2\n" +
"3\n" +
"4\n" +
"5\n" +
"6\n";
}
default:
throw new IllegalArgumentException("Invalid program id");
}
}
}
public static class UniqueLongMapper implements MapPartitionFunction<Tuple3<Integer,Long,String>, Long> {
private static final long serialVersionUID = 1L;
@Override
public void mapPartition(Iterable<Tuple3<Integer, Long, String>> records, Collector<Long> out) throws Exception {
HashSet<Long> uniq = new HashSet<Long>();
for(Tuple3<Integer,Long,String> t : records) {
uniq.add(t.f1);
}
for(Long l : uniq) {
out.collect(l);
}
}
}
public static class PartitionIndexMapper extends RichMapFunction<Long, Tuple2<Integer, Integer>> {
private static final long serialVersionUID = 1L;
@Override
public Tuple2<Integer, Integer> map(Long value) throws Exception {
return new Tuple2<Integer, Integer>(this.getRuntimeContext().getIndexOfThisSubtask(), 1);
}
}
}