Commit ac42d150 authored by zentol

[FLINK-2439] [py] Expand DataSet feature coverage

Parent ab847071
@@ -473,7 +473,7 @@ public class PythonPlanBinder {

	private void createDistinctOperation(PythonOperationInfo info) throws IOException {
		DataSet op = (DataSet) sets.get(info.parentID);
-		sets.put(info.setID, info.keys.length == 0 ? op.distinct() : op.distinct(info.keys).name("Distinct").map(new KeyDiscarder()).name("DistinctPostStep"));
+		sets.put(info.setID, op.distinct(info.keys).name("Distinct").map(new KeyDiscarder()).name("DistinctPostStep"));
	}

	private void createFirstOperation(PythonOperationInfo info) throws IOException {
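The keys.length == 0 branch becomes dead code here: the reworked Python distinct() below always prepends a key-extraction pre-step, so info.keys is never empty by the time the plan reaches the binder. A minimal Python sketch of that guarantee (illustrative values, not the actual module):

    fields = ()                # user called distinct() with no arguments
    if len(fields) == 0:
        fields = (0,)          # the whole element becomes the key
    assert len(fields) > 0     # so op.distinct() without keys is unreachable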
@@ -14,16 +14,15 @@ package org.apache.flink.python.api.functions.util;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields;
-import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;

/*
Utility function to extract the value from a Key-Value Tuple.
*/
@ForwardedFields("f1->*")
-public class KeyDiscarder implements MapFunction<Tuple2<Tuple, byte[]>, byte[]> {
+public class KeyDiscarder<T> implements MapFunction<Tuple2<T, byte[]>, byte[]> {
	@Override
-	public byte[] map(Tuple2<Tuple, byte[]> value) throws Exception {
+	public byte[] map(Tuple2<T, byte[]> value) throws Exception {
		return value.f1;
	}
}
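KeyDiscarder is the shared post-step for the new operations: each pre-step emits (key, value) pairs, and this map keeps only the value. The @ForwardedFields("f1->*") annotation declares that input field f1 becomes the entire output, which lets the optimizer carry known data properties across the map. A Python-side sketch of the pair layout (values are assumptions):

    pair = ((1,), b"serialized element")   # (key tuple, payload) as emitted by a pre-step
    payload = pair[1]                      # what KeyDiscarder.map returns (value.f1)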
@@ -45,6 +45,10 @@ class _Identifier(object):
    SINK_TEXT = "sink_text"
    SINK_PRINT = "sink_print"
    BROADCAST = "broadcast"
    FIRST = "first"
    DISTINCT = "distinct"
    PARTITION_HASH = "partition_hash"
    REBALANCE = "rebalance"


class WriteMode(object):
@@ -277,6 +277,36 @@ class DataSet(object):
        self._env._sets.append(child)
        return child_set
    def distinct(self, *fields):
        """
        Returns a distinct set of a tuple DataSet, using field position keys or a key-selector function.

        :param fields: One or more field positions (or a single key-selector) on which distinctness is decided; if omitted, whole elements are compared.
        :return: The distinct DataSet.
        """
        f = None
        if len(fields) == 0:
            f = lambda x: (x,)
            fields = (0,)
        if isinstance(fields[0], TYPES.FunctionType):
            f = lambda x: (fields[0](x),)
        if isinstance(fields[0], KeySelectorFunction):
            f = lambda x: (fields[0].get_key(x),)
        if f is None:
            f = lambda x: tuple([x[key] for key in fields])
        return self.map(lambda x: (f(x), x)).name("DistinctPreStep")._distinct(tuple([x for x in range(len(fields))]))
    def _distinct(self, fields):
        self._info.types = _createKeyValueTypeInfo(len(fields))
        child = OperationInfo()
        child_set = DataSet(self._env, child)
        child.identifier = _Identifier.DISTINCT
        child.parent = self._info
        child.keys = fields
        self._info.children.append(child)
        self._env._sets.append(child)
        return child_set
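A usage sketch for the new distinct() variants; env, the element values, and the lambda are assumptions, in the style of the test program at the end of this commit:

    d = env.from_elements((1, 0.5, "hello"), (1, 0.5, "hello"), (2, 0.4, "world"))
    d.distinct()                 # distinct on whole elements; keys default to (0,)
    d.distinct(0, 2)             # distinct on field positions 0 and 2
    d.distinct(lambda x: x[2])   # distinct on a key-selector function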
    def filter(self, operator):
        """
        Applies a Filter transformation on a DataSet.
@@ -302,6 +332,22 @@ class DataSet(object):
        self._env._sets.append(child)
        return child_set
    def first(self, count):
        """
        Returns a new set containing the first n elements in this DataSet.

        :param count: The desired number of elements.
        :return: A DataSet containing the elements.
        """
        child = OperationInfo()
        child_set = DataSet(self._env, child)
        child.identifier = _Identifier.FIRST
        child.parent = self._info
        child.count = count
        self._info.children.append(child)
        self._env._sets.append(child)
        return child_set
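For example (data assumed), first() on an ungrouped DataSet truncates it to at most count elements; which elements are kept is not defined for parallel input:

    d = env.from_elements(1, 12, 15, 17)
    first_two = d.first(2)   # a DataSet containing two of the four elements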
    def flat_map(self, operator):
        """
        Applies a FlatMap transformation on a DataSet.
@@ -426,6 +472,52 @@ class DataSet(object):
        self._env._sets.append(child)
        return child_set
    def partition_by_hash(self, *fields):
        f = None
        if len(fields) == 0:
            f = lambda x: (x,)
            fields = (0,)  # fix: key on the whole element, as distinct() does; otherwise fields[0] below raises IndexError
        if isinstance(fields[0], TYPES.FunctionType):
            f = lambda x: (fields[0](x),)
        if isinstance(fields[0], KeySelectorFunction):
            f = lambda x: (fields[0].get_key(x),)
        if f is None:
            f = lambda x: tuple([x[key] for key in fields])
        return self.map(lambda x: (f(x), x)).name("HashPartitionPreStep")._partition_by_hash(tuple([x for x in range(len(fields))]))
    def _partition_by_hash(self, fields):
        """
        Hash-partitions a DataSet on the specified key fields.

        Important: This operation shuffles the whole DataSet over the network and can take a significant amount of time.

        :param fields: The field indexes on which the DataSet is hash-partitioned.
        :return: The partitioned DataSet.
        """
        self._info.types = _createKeyValueTypeInfo(len(fields))
        child = OperationInfo()
        child_set = DataSet(self._env, child)
        child.identifier = _Identifier.PARTITION_HASH
        child.parent = self._info
        child.keys = fields
        self._info.children.append(child)
        self._env._sets.append(child)
        return child_set
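A usage sketch; like distinct(), partition_by_hash() accepts field positions, a plain function, or a KeySelectorFunction (data assumed):

    d = env.from_elements((1, "a"), (2, "b"), (3, "c"))
    d.partition_by_hash(0)               # redistribute by the first field
    d.partition_by_hash(lambda x: x[1])  # or by a key-selector function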
    def rebalance(self):
        """
        Enforces a re-balancing of the DataSet, i.e. the DataSet is evenly distributed over all parallel instances of the
        following task. This can help to improve performance in case of heavy data skew and compute-intensive operations.

        Important: This operation shuffles the whole DataSet over the network and can take a significant amount of time.

        :return: The re-balanced DataSet.
        """
        child = OperationInfo()
        child_set = DataSet(self._env, child)
        child.identifier = _Identifier.REBALANCE
        child.parent = self._info
        self._info.children.append(child)
        self._env._sets.append(child)
        return child_set
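A typical use, sketched under the assumption that filter() accepts a lambda like map() does: rebalance after a selective filter so the surviving elements are spread evenly before an expensive transformation:

    # expensive_fn is a placeholder for a costly transformation
    d.filter(lambda x: x[0] > 0).rebalance().map(lambda x: expensive_fn(x))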
    def union(self, other_set):
        """
        Creates a union of this DataSet with another DataSet.
@@ -445,6 +537,10 @@ class DataSet(object):
        self._env._sets.append(child)
        return child_set

    def name(self, name):
        self._info.name = name
        return self
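name() labels the operator in the generated plan; it is how the pre- and post-steps above become identifiable (e.g. "DistinctPreStep", "HashPartitionPreStep"). Since it returns self, it chains:

    d.map(lambda x: x).name("IdentityMap")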

class OperatorSet(DataSet):
    def __init__(self, env, info):
@@ -472,6 +568,23 @@ class Grouping(object):
    def _finalize(self):
        pass
    def first(self, count):
        """
        Returns a new set containing the first n elements in this DataSet.

        :param count: The desired number of elements.
        :return: A DataSet containing the elements.
        """
        self._finalize()
        child = OperationInfo()
        child_set = DataSet(self._env, child)
        child.identifier = _Identifier.FIRST
        child.parent = self._info
        child.count = count
        self._info.children.append(child)
        self._env._sets.append(child)
        return child_set
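On a Grouping, first(count) keeps up to count elements per group rather than overall. A sketch, with group_by() and the data as assumptions:

    d = env.from_elements((1, "a"), (1, "b"), (2, "c"))
    d.group_by(0).first(1)   # one element for each distinct key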
    def reduce_group(self, operator, combinable=False):
        """
        Applies a GroupReduce transformation.
@@ -281,6 +281,14 @@ class Environment(object):
            collect(set.id)
            collect(set.parent.id)
            for case in Switch(identifier):
                if case(_Identifier.REBALANCE):
                    break
                if case(_Identifier.DISTINCT, _Identifier.PARTITION_HASH):
                    collect(set.keys)
                    break
                if case(_Identifier.FIRST):
                    collect(set.count)
                    break
                if case(_Identifier.SORT):
                    collect(set.field)
                    collect(set.order)
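The Switch/case construct above is a common Python switch recipe; each break exits the dispatch once a case matches. The same logic as plain if/elif, for reference (illustrative, not the actual module):

    def collect_parameters(op, collect):
        if op.identifier == _Identifier.REBALANCE:
            pass                   # no parameters beyond the ids
        elif op.identifier in (_Identifier.DISTINCT, _Identifier.PARTITION_HASH):
            collect(op.keys)       # tuple of key positions
        elif op.identifier == _Identifier.FIRST:
            collect(op.count)      # number of elements to keep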
@@ -41,6 +41,7 @@ class OperationInfo():
        self.sinks = []
        self.children = []
        self.path = None
        self.count = 0
        self.values = []
        self.projections = []
        self.bcvars = []
@@ -142,6 +142,16 @@ if __name__ == "__main__":
    d2 \
        .union(d4) \
        .map_partition(Verify2([(1, 0.5, "hello", True), (2, 0.4, "world", False), (1, 0.5, "hello", True), (1, 0.4, "hello", False), (1, 0.5, "hello", True), (2, 0.4, "world", False)], "Union")).output()

    d1 \
        .first(1) \
        .map_partition(Verify([1], "First")).output()

    d1 \
        .rebalance()

    d6 \
        .distinct() \
        .map_partition(Verify([1, 12], "Distinct")).output()

    d2 \
        .partition_by_hash(3)

    # Execution