Commit c1dd9604 authored by zentol

[FLINK-1945][py] Python Tests less verbose

This closes #1376
Parent 40cb4e0d
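
The gist of the change: the Python test programs used to print pass/fail messages and log every file under test; after this commit the verification UDFs raise an exception on the first mismatch, so a failure aborts the job and surfaces through the JUnit harness instead of producing log output. A minimal standalone sketch of that pattern (hypothetical dataset and expected values; written against the 0.10-era Python API used in the diff below):

from flink.plan.Environment import get_environment
from flink.functions.MapPartitionFunction import MapPartitionFunction
from flink.plan.Constants import INT, STRING

class Verify(MapPartitionFunction):
    # Compares the incoming values against an expected list, in order.
    def __init__(self, expected, name):
        super(Verify, self).__init__()
        self.expected = expected
        self.name = name

    def map_partition(self, iterator, collector):
        index = 0
        for value in iterator:
            if value != self.expected[index]:
                # Raising fails the whole job; nothing is printed or
                # collected on success, which keeps the test output quiet.
                raise Exception(self.name + " failed. Expected: " + str(self.expected[index]) + " Actual: " + str(value))
            index += 1

if __name__ == "__main__":
    env = get_environment()
    # Hypothetical pipeline; the real suite below exercises many more operators.
    env.from_elements(1, 2, 3) \
        .map((lambda x: x * x), INT) \
        .map_partition(Verify([1, 4, 9], "Square"), STRING).output()
    env.set_degree_of_parallelism(1)
    env.execute(local=True)
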
......@@ -77,5 +77,11 @@ under the License.
			<artifactId>flink-clients</artifactId>
			<version>${project.version}</version>
		</dependency>
		<dependency>
			<groupId>org.apache.flink</groupId>
			<artifactId>flink-test-utils</artifactId>
			<version>${project.version}</version>
			<scope>test</scope>
		</dependency>
	</dependencies>
</project>
......@@ -20,26 +20,16 @@ import org.apache.flink.core.fs.FileSystem;
import org.apache.flink.core.fs.Path;
import static org.apache.flink.python.api.PythonPlanBinder.ARGUMENT_PYTHON_2;
import static org.apache.flink.python.api.PythonPlanBinder.ARGUMENT_PYTHON_3;
import org.junit.Test;
import org.junit.BeforeClass;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.flink.test.util.JavaProgramTestBase;
public class PythonPlanBinderTest {
	private static final Logger LOG = LoggerFactory.getLogger(PythonPlanBinder.class);
	private static boolean python2Supported = true;
	private static boolean python3Supported = true;
	private static List<String> TEST_FILES;

	@BeforeClass
	public static void setup() throws Exception {
		findTestFiles();
		checkPythonSupport();

public class PythonPlanBinderTest extends JavaProgramTestBase {
	@Override
	protected boolean skipCollectionExecution() {
		return true;
	}

	private static void findTestFiles() throws Exception {
		TEST_FILES = new ArrayList();
	private static List<String> findTestFiles() throws Exception {
		List<String> files = new ArrayList();
		FileSystem fs = FileSystem.getLocalFileSystem();
		FileStatus[] status = fs.listStatus(
			new Path(fs.getWorkingDirectory().toString()
......@@ -47,41 +37,39 @@ public class PythonPlanBinderTest {
		for (FileStatus f : status) {
			String file = f.getPath().toString();
			if (file.endsWith(".py")) {
				TEST_FILES.add(file);
				files.add(file);
			}
		}
		return files;
	}

	private static void checkPythonSupport() {
	private static boolean isPython2Supported() {
		try {
			Runtime.getRuntime().exec("python");
			return true;
		} catch (IOException ex) {
			python2Supported = false;
			LOG.info("No Python 2 runtime detected.");
			return false;
		}
	}

	private static boolean isPython3Supported() {
		try {
			Runtime.getRuntime().exec("python3");
			return true;
		} catch (IOException ex) {
			python3Supported = false;
			LOG.info("No Python 3 runtime detected.");
			return false;
		}
	}

	@Test
	public void testPython2() throws Exception {
		if (python2Supported) {
			for (String file : TEST_FILES) {
				LOG.info("testing " + file);
	@Override
	protected void testProgram() throws Exception {
		if (isPython2Supported()) {
			for (String file : findTestFiles()) {
				PythonPlanBinder.main(new String[]{ARGUMENT_PYTHON_2, file});
			}
		}
	}

	@Test
	public void testPython3() throws Exception {
		if (python3Supported) {
			for (String file : TEST_FILES) {
				LOG.info("testing " + file);
		if (isPython3Supported()) {
			for (String file : findTestFiles()) {
				PythonPlanBinder.main(new String[]{ARGUMENT_PYTHON_3, file});
			}
		}
	}
......
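
For context: JavaProgramTestBase, pulled in through the new flink-test-utils dependency, starts a local test cluster and calls testProgram() on it; any exception thrown there fails the test, which is what makes the raise-based verification in the Python programs work. Returning true from skipCollectionExecution() opts the test out of the collection-based execution mode (presumably because the Python plan binder needs a real cluster environment). A minimal sketch of that contract (class name and body are hypothetical):

import org.apache.flink.test.util.JavaProgramTestBase;

public class MyProgramTest extends JavaProgramTestBase {

	@Override
	protected boolean skipCollectionExecution() {
		// Run only against the local test cluster.
		return true;
	}

	@Override
	protected void testProgram() throws Exception {
		// Called by the base class once the cluster is up;
		// throwing here fails the JUnit test.
	}
}
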
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from flink.plan.Environment import get_environment
from flink.plan.Constants import INT, STRING
from flink.plan.Constants import WriteMode
if __name__ == "__main__":
    env = get_environment()

    d1 = env.read_csv("src/test/python/org/apache/flink/python/api/data_csv", (INT, INT, STRING))
    d1.write_csv("/tmp/flink/result", line_delimiter="\n", field_delimiter="|", write_mode=WriteMode.OVERWRITE)

    env.set_degree_of_parallelism(1)
    env.execute(local=True)
......@@ -21,107 +21,11 @@ from flink.functions.FlatMapFunction import FlatMapFunction
from flink.functions.FilterFunction import FilterFunction
from flink.functions.MapPartitionFunction import MapPartitionFunction
from flink.functions.ReduceFunction import ReduceFunction
from flink.functions.CrossFunction import CrossFunction
from flink.functions.JoinFunction import JoinFunction
from flink.functions.GroupReduceFunction import GroupReduceFunction
from flink.functions.CoGroupFunction import CoGroupFunction
from flink.plan.Constants import INT, STRING, FLOAT, BOOL, CUSTOM, Order
from flink.plan.Constants import INT, STRING, FLOAT, BOOL, BYTES, CUSTOM, Order, WriteMode
import struct
class Mapper(MapFunction):
    def map(self, value):
        return value * value

class Filter(FilterFunction):
    def __init__(self, limit):
        super(Filter, self).__init__()
        self.limit = limit

    def filter(self, value):
        return value > self.limit

class FlatMap(FlatMapFunction):
    def flat_map(self, value, collector):
        collector.collect(value)
        collector.collect(value * 2)

class MapPartition(MapPartitionFunction):
    def map_partition(self, iterator, collector):
        for value in iterator:
            collector.collect(value * 2)

class Reduce(ReduceFunction):
    def reduce(self, value1, value2):
        return value1 + value2

class Reduce2(ReduceFunction):
    def reduce(self, value1, value2):
        return (value1[0] + value2[0], value1[1] + value2[1], value1[2], value1[3] or value2[3])

class Cross(CrossFunction):
    def cross(self, value1, value2):
        return (value1, value2[3])

class MapperBcv(MapFunction):
    def map(self, value):
        factor = self.context.get_broadcast_variable("test")[0][0]
        return value * factor

class Join(JoinFunction):
    def join(self, value1, value2):
        if value1[3]:
            return value2[0] + str(value1[0])
        else:
            return value2[0] + str(value1[1])

class GroupReduce(GroupReduceFunction):
    def reduce(self, iterator, collector):
        if iterator.has_next():
            i, f, s, b = iterator.next()
            for value in iterator:
                i += value[0]
                f += value[1]
                b |= value[3]
            collector.collect((i, f, s, b))

class GroupReduce2(GroupReduceFunction):
    def reduce(self, iterator, collector):
        for value in iterator:
            collector.collect(value)

class GroupReduce3(GroupReduceFunction):
    def reduce(self, iterator, collector):
        collector.collect(iterator.next())

    def combine(self, iterator, collector):
        if iterator.has_next():
            v1 = iterator.next()
            if iterator.has_next():
                v2 = iterator.next()
                if v1[0] < v2[0]:
                    collector.collect(v1)
                else:
                    collector.collect(v2)

class CoGroup(CoGroupFunction):
    def co_group(self, iterator1, iterator2, collector):
        while iterator1.has_next() and iterator2.has_next():
            collector.collect((iterator1.next(), iterator2.next()))

#Utilities
class Id(MapFunction):
    def map(self, value):
        return value
......@@ -137,10 +41,9 @@ class Verify(MapPartitionFunction):
        index = 0
        for value in iterator:
            if value != self.expected[index]:
                print(self.name + " Test failed. Expected: " + str(self.expected[index]) + " Actual: " + str(value))
                raise Exception(self.name + " failed!")
                raise Exception(self.name + " Test failed. Expected: " + str(self.expected[index]) + " Actual: " + str(value))
            index += 1
        collector.collect(self.name + " successful!")
        #collector.collect(self.name + " successful!")
class Verify2(MapPartitionFunction):
......@@ -155,8 +58,8 @@ class Verify2(MapPartitionFunction):
            try:
                self.expected.remove(value)
            except Exception:
                raise Exception(self.name + " failed!")
        collector.collect(self.name + " successful!")
                raise Exception(self.name + " failed! Actual value " + str(value) + " not contained in expected values: " + str(self.expected))
        #collector.collect(self.name + " successful!")
if __name__ == "__main__":
......@@ -172,93 +75,33 @@ if __name__ == "__main__":
    d5 = env.from_elements((4.4, 4.3, 1), (4.3, 4.4, 1), (4.2, 4.1, 3), (4.1, 4.1, 3))

    d1 \
        .map((lambda x: x * x), INT).map(Mapper(), INT) \
        .map_partition(Verify([1, 1296, 20736], "Map"), STRING).output()
    d6 = env.from_elements(1, 1, 12)

    d1 \
        .map(Mapper(), INT).map((lambda x: x * x), INT) \
        .map_partition(Verify([1, 1296, 20736], "Chained Lambda"), STRING).output()
    #CSV Source/Sink
    csv_data = env.read_csv("src/test/python/org/apache/flink/python/api/data_csv", (INT, INT, STRING))

    d1 \
        .filter(Filter(5)).filter(Filter(8)) \
        .map_partition(Verify([12], "Filter"), STRING).output()
    csv_data.write_csv("/tmp/flink/result1", line_delimiter="\n", field_delimiter="|", write_mode=WriteMode.OVERWRITE)

    d1 \
        .flat_map(FlatMap(), INT).flat_map(FlatMap(), INT) \
        .map_partition(Verify([1, 2, 2, 4, 6, 12, 12, 24, 12, 24, 24, 48], "FlatMap"), STRING).output()
    #Text Source/Sink
    text_data = env.read_text("src/test/python/org/apache/flink/python/api/data_text")

    d1 \
        .map_partition(MapPartition(), INT) \
        .map_partition(Verify([2, 12, 24], "MapPartition"), STRING).output()
    text_data.write_text("/tmp/flink/result2", WriteMode.OVERWRITE)

    d1 \
        .reduce(Reduce()) \
        .map_partition(Verify([19], "AllReduce"), STRING).output()

    d4 \
        .group_by(2).reduce(Reduce2()) \
        .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "CombineReduce"), STRING).output()

    d4 \
        .map(Id(), (INT, FLOAT, STRING, BOOL)).group_by(2).reduce(Reduce2()) \
        .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "ChainedReduce"), STRING).output()

    d1 \
        .map(MapperBcv(), INT).with_broadcast_set("test", d2) \
        .map_partition(Verify([1, 6, 12], "Broadcast"), STRING).output()

    d1 \
        .cross(d2).using(Cross(), (INT, BOOL)) \
        .map_partition(Verify([(1, True), (1, False), (6, True), (6, False), (12, True), (12, False)], "Cross"), STRING).output()

    d1 \
        .cross(d3) \
        .map_partition(Verify([(1, ("hello",)), (1, ("world",)), (6, ("hello",)), (6, ("world",)), (12, ("hello",)), (12, ("world",))], "Default Cross"), STRING).output()
    #Types
    env.from_elements(bytearray(b"hello"), bytearray(b"world")) \
        .map(Id(), BYTES).map_partition(Verify([bytearray(b"hello"), bytearray(b"world")], "Byte"), STRING).output()

    d2 \
        .cross(d3).project_second(0).project_first(0, 1) \
        .map_partition(Verify([("hello", 1, 0.5), ("world", 1, 0.5), ("hello", 2, 0.4), ("world", 2, 0.4)], "Project Cross"), STRING).output()
    env.from_elements(1, 2, 3, 4, 5) \
        .map(Id(), INT).map_partition(Verify([1, 2, 3, 4, 5], "Int"), STRING).output()

    d2 \
        .join(d3).where(2).equal_to(0).using(Join(), STRING) \
        .map_partition(Verify(["hello1", "world0.4"], "Join"), STRING).output()
    env.from_elements(True, True, False) \
        .map(Id(), BOOL).map_partition(Verify([True, True, False], "Bool"), STRING).output()

    d2 \
        .join(d3).where(2).equal_to(0).project_first(0, 3).project_second(0) \
        .map_partition(Verify([(1, True, "hello"), (2, False, "world")], "Project Join"), STRING).output()
    env.from_elements(1.4, 1.7, 12312.23) \
        .map(Id(), FLOAT).map_partition(Verify([1.4, 1.7, 12312.23], "Float"), STRING).output()

    d2 \
        .join(d3).where(2).equal_to(0) \
        .map_partition(Verify([((1, 0.5, "hello", True), ("hello",)), ((2, 0.4, "world", False), ("world",))], "Default Join"), STRING).output()

    d2 \
        .project(0, 1, 2) \
        .map_partition(Verify([(1, 0.5, "hello"), (2, 0.4, "world")], "Project"), STRING).output()

    d2 \
        .union(d4) \
        .map_partition(Verify2([(1, 0.5, "hello", True), (2, 0.4, "world", False), (1, 0.5, "hello", True), (1, 0.4, "hello", False), (1, 0.5, "hello", True), (2, 0.4, "world", False)], "Union"), STRING).output()

    d4 \
        .group_by(2).reduce_group(GroupReduce(), (INT, FLOAT, STRING, BOOL), combinable=False) \
        .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "AllGroupReduce"), STRING).output()

    d4 \
        .map(Id(), (INT, FLOAT, STRING, BOOL)).group_by(2).reduce_group(GroupReduce(), (INT, FLOAT, STRING, BOOL), combinable=True) \
        .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "ChainedGroupReduce"), STRING).output()

    d4 \
        .group_by(2).reduce_group(GroupReduce(), (INT, FLOAT, STRING, BOOL), combinable=True) \
        .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "CombineGroupReduce"), STRING).output()

    d5 \
        .group_by(2).sort_group(0, Order.DESCENDING).sort_group(1, Order.ASCENDING).reduce_group(GroupReduce3(), (FLOAT, FLOAT, INT), combinable=True) \
        .map_partition(Verify([(4.3, 4.4, 1), (4.1, 4.1, 3)], "ChainedSortedGroupReduce"), STRING).output()

    d4 \
        .co_group(d5).where(0).equal_to(2).using(CoGroup(), ((INT, FLOAT, STRING, BOOL), (FLOAT, FLOAT, INT))) \
        .map_partition(Verify([((1, 0.5, "hello", True), (4.4, 4.3, 1)), ((1, 0.4, "hello", False), (4.3, 4.4, 1))], "CoGroup"), STRING).output()
    env.from_elements("hello", "world") \
        .map(Id(), STRING).map_partition(Verify(["hello", "world"], "String"), STRING).output()

    #Custom Serialization
    class Ext(MapPartitionFunction):
......@@ -285,6 +128,105 @@ if __name__ == "__main__":
        .map(Id(), CUSTOM).map_partition(Ext(), INT) \
        .map_partition(Verify([2, 4], "CustomTypeSerialization"), STRING).output()

    #Map
    class Mapper(MapFunction):
        def map(self, value):
            return value * value
    d1 \
        .map((lambda x: x * x), INT).map(Mapper(), INT) \
        .map_partition(Verify([1, 1296, 20736], "Map"), STRING).output()

    #FlatMap
    class FlatMap(FlatMapFunction):
        def flat_map(self, value, collector):
            collector.collect(value)
            collector.collect(value * 2)
    d1 \
        .flat_map(FlatMap(), INT).flat_map(FlatMap(), INT) \
        .map_partition(Verify([1, 2, 2, 4, 6, 12, 12, 24, 12, 24, 24, 48], "FlatMap"), STRING).output()

    #MapPartition
    class MapPartition(MapPartitionFunction):
        def map_partition(self, iterator, collector):
            for value in iterator:
                collector.collect(value * 2)
    d1 \
        .map_partition(MapPartition(), INT) \
        .map_partition(Verify([2, 12, 24], "MapPartition"), STRING).output()

    #Filter
    class Filter(FilterFunction):
        def __init__(self, limit):
            super(Filter, self).__init__()
            self.limit = limit

        def filter(self, value):
            return value > self.limit
    d1 \
        .filter(Filter(5)).filter(Filter(8)) \
        .map_partition(Verify([12], "Filter"), STRING).output()

    #Reduce
    class Reduce(ReduceFunction):
        def reduce(self, value1, value2):
            return value1 + value2

    class Reduce2(ReduceFunction):
        def reduce(self, value1, value2):
            return (value1[0] + value2[0], value1[1] + value2[1], value1[2], value1[3] or value2[3])
    d1 \
        .reduce(Reduce()) \
        .map_partition(Verify([19], "AllReduce"), STRING).output()
    d4 \
        .group_by(2).reduce(Reduce2()) \
        .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "CombineReduce"), STRING).output()
    d4 \
        .map(Id(), (INT, FLOAT, STRING, BOOL)).group_by(2).reduce(Reduce2()) \
        .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "ChainedReduce"), STRING).output()

    #GroupReduce
    class GroupReduce(GroupReduceFunction):
        def reduce(self, iterator, collector):
            if iterator.has_next():
                i, f, s, b = iterator.next()
                for value in iterator:
                    i += value[0]
                    f += value[1]
                    b |= value[3]
                collector.collect((i, f, s, b))

    class GroupReduce2(GroupReduceFunction):
        def reduce(self, iterator, collector):
            for value in iterator:
                collector.collect(value)

    class GroupReduce3(GroupReduceFunction):
        def reduce(self, iterator, collector):
            collector.collect(iterator.next())

        def combine(self, iterator, collector):
            if iterator.has_next():
                v1 = iterator.next()
                if iterator.has_next():
                    v2 = iterator.next()
                    if v1[0] < v2[0]:
                        collector.collect(v1)
                    else:
                        collector.collect(v2)
    d4 \
        .group_by(2).reduce_group(GroupReduce(), (INT, FLOAT, STRING, BOOL), combinable=False) \
        .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "AllGroupReduce"), STRING).output()
    d4 \
        .map(Id(), (INT, FLOAT, STRING, BOOL)).group_by(2).reduce_group(GroupReduce(), (INT, FLOAT, STRING, BOOL), combinable=True) \
        .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "ChainedGroupReduce"), STRING).output()
    d4 \
        .group_by(2).reduce_group(GroupReduce(), (INT, FLOAT, STRING, BOOL), combinable=True) \
        .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "CombineGroupReduce"), STRING).output()
    d5 \
        .group_by(2).sort_group(0, Order.DESCENDING).sort_group(1, Order.ASCENDING).reduce_group(GroupReduce3(), (FLOAT, FLOAT, INT), combinable=True) \
        .map_partition(Verify([(4.3, 4.4, 1), (4.1, 4.1, 3)], "ChainedSortedGroupReduce"), STRING).output()

    #Execution
    env.set_degree_of_parallelism(1)
    env.execute(local=True)
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from flink.plan.Environment import get_environment
from flink.functions.MapFunction import MapFunction
from flink.functions.MapPartitionFunction import MapPartitionFunction
from flink.functions.CrossFunction import CrossFunction
from flink.functions.JoinFunction import JoinFunction
from flink.functions.CoGroupFunction import CoGroupFunction
from flink.plan.Constants import BOOL, INT, FLOAT, STRING
#Utilities
class Id(MapFunction):
    def map(self, value):
        return value

class Verify(MapPartitionFunction):
    def __init__(self, expected, name):
        super(Verify, self).__init__()
        self.expected = expected
        self.name = name

    def map_partition(self, iterator, collector):
        index = 0
        for value in iterator:
            if value != self.expected[index]:
                raise Exception(self.name + " Test failed. Expected: " + str(self.expected[index]) + " Actual: " + str(value))
            index += 1
        #collector.collect(self.name + " successful!")
class Verify2(MapPartitionFunction):
    def __init__(self, expected, name):
        super(Verify2, self).__init__()
        self.expected = expected
        self.name = name

    def map_partition(self, iterator, collector):
        for value in iterator:
            if value in self.expected:
                try:
                    self.expected.remove(value)
                except Exception:
                    raise Exception(self.name + " failed! Actual value " + str(value) + " not contained in expected values: " + str(self.expected))
        #collector.collect(self.name + " successful!")
if __name__ == "__main__":
    env = get_environment()

    d1 = env.from_elements(1, 6, 12)
    d2 = env.from_elements((1, 0.5, "hello", True), (2, 0.4, "world", False))
    d3 = env.from_elements(("hello",), ("world",))
    d4 = env.from_elements((1, 0.5, "hello", True), (1, 0.4, "hello", False), (1, 0.5, "hello", True), (2, 0.4, "world", False))
    d5 = env.from_elements((4.4, 4.3, 1), (4.3, 4.4, 1), (4.2, 4.1, 3), (4.1, 4.1, 3))
    d6 = env.from_elements(1, 1, 12)

    #Join
    class Join(JoinFunction):
        def join(self, value1, value2):
            if value1[3]:
                return value2[0] + str(value1[0])
            else:
                return value2[0] + str(value1[1])
    d2 \
        .join(d3).where(2).equal_to(0).using(Join(), STRING) \
        .map_partition(Verify(["hello1", "world0.4"], "Join"), STRING).output()
    d2 \
        .join(d3).where(2).equal_to(0).project_first(0, 3).project_second(0) \
        .map_partition(Verify([(1, True, "hello"), (2, False, "world")], "Project Join"), STRING).output()
    d2 \
        .join(d3).where(2).equal_to(0) \
        .map_partition(Verify([((1, 0.5, "hello", True), ("hello",)), ((2, 0.4, "world", False), ("world",))], "Default Join"), STRING).output()

    #Cross
    class Cross(CrossFunction):
        def cross(self, value1, value2):
            return (value1, value2[3])
    d1 \
        .cross(d2).using(Cross(), (INT, BOOL)) \
        .map_partition(Verify([(1, True), (1, False), (6, True), (6, False), (12, True), (12, False)], "Cross"), STRING).output()
    d1 \
        .cross(d3) \
        .map_partition(Verify([(1, ("hello",)), (1, ("world",)), (6, ("hello",)), (6, ("world",)), (12, ("hello",)), (12, ("world",))], "Default Cross"), STRING).output()
    d2 \
        .cross(d3).project_second(0).project_first(0, 1) \
        .map_partition(Verify([("hello", 1, 0.5), ("world", 1, 0.5), ("hello", 2, 0.4), ("world", 2, 0.4)], "Project Cross"), STRING).output()

    #CoGroup
    class CoGroup(CoGroupFunction):
        def co_group(self, iterator1, iterator2, collector):
            while iterator1.has_next() and iterator2.has_next():
                collector.collect((iterator1.next(), iterator2.next()))
    d4 \
        .co_group(d5).where(0).equal_to(2).using(CoGroup(), ((INT, FLOAT, STRING, BOOL), (FLOAT, FLOAT, INT))) \
        .map_partition(Verify([((1, 0.5, "hello", True), (4.4, 4.3, 1)), ((1, 0.4, "hello", False), (4.3, 4.4, 1))], "CoGroup"), STRING).output()

    #Broadcast
    class MapperBcv(MapFunction):
        def map(self, value):
            factor = self.context.get_broadcast_variable("test")[0][0]
            return value * factor
    d1 \
        .map(MapperBcv(), INT).with_broadcast_set("test", d2) \
        .map_partition(Verify([1, 6, 12], "Broadcast"), STRING).output()

    #Misc
    class Mapper(MapFunction):
        def map(self, value):
            return value * value
    d1 \
        .map(Mapper(), INT).map((lambda x: x * x), INT) \
        .map_partition(Verify([1, 1296, 20736], "Chained Lambda"), STRING).output()
    d2 \
        .project(0, 1, 2) \
        .map_partition(Verify([(1, 0.5, "hello"), (2, 0.4, "world")], "Project"), STRING).output()
    d2 \
        .union(d4) \
        .map_partition(Verify2([(1, 0.5, "hello", True), (2, 0.4, "world", False), (1, 0.5, "hello", True), (1, 0.4, "hello", False), (1, 0.5, "hello", True), (2, 0.4, "world", False)], "Union"), STRING).output()

    #Execution
    env.set_degree_of_parallelism(1)
    env.execute(local=True)
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from flink.plan.Environment import get_environment
from flink.plan.Constants import WriteMode
if __name__ == "__main__":
    env = get_environment()

    d1 = env.read_text("src/test/python/org/apache/flink/python/api/data_text")
    d1.write_text("/tmp/flink/result", WriteMode.OVERWRITE)

    env.set_degree_of_parallelism(1)
    env.execute(local=True)
......@@ -16,8 +16,20 @@
# limitations under the License.
################################################################################
from flink.plan.Environment import get_environment
from flink.plan.Constants import INT, STRING, BOOL, FLOAT
import sys
from flink.plan.Constants import BOOL, STRING
from flink.functions.MapPartitionFunction import MapPartitionFunction
class Verify(MapPartitionFunction):
    def __init__(self, msg):
        super(Verify, self).__init__()
        self.msg = msg

    def map_partition(self, iterator, collector):
        if self.msg is None:
            return
        else:
            raise Exception("Type Deduction failed: " + self.msg)

if __name__ == "__main__":
    env = get_environment()
......@@ -28,35 +40,34 @@ if __name__ == "__main__":
    direct_from_source = d1.filter(lambda x: True)

    msg = None

    if direct_from_source._info.types != ("hello", 4, 3.2, True):
        sys.exit("Error deducing type directly from source.")
        msg = "Error deducing type directly from source."

    from_common_udf = d1.map(lambda x: x[3], BOOL).filter(lambda x: True)

    if from_common_udf._info.types != BOOL:
        sys.exit("Error deducing type from common udf.")
        msg = "Error deducing type from common udf."

    through_projection = d1.project(3, 2).filter(lambda x: True)

    if through_projection._info.types != (True, 3.2):
        sys.exit("Error deducing type through projection.")
        msg = "Error deducing type through projection."

    through_default_op = d1.cross(d2).filter(lambda x: True)

    if through_default_op._info.types != (("hello", 4, 3.2, True), "world"):
        sys.exit("Error deducing type through default J/C." + str(through_default_op._info.types))
        msg = "Error deducing type through default J/C." + str(through_default_op._info.types)

    through_prj_op = d1.cross(d2).project_first(1, 0).project_second().project_first(3, 2).filter(lambda x: True)

    if through_prj_op._info.types != (4, "hello", "world", True, 3.2):
        sys.exit("Error deducing type through projection J/C. " + str(through_prj_op._info.types))
        msg = "Error deducing type through projection J/C. " + str(through_prj_op._info.types)

    env = get_environment()
    msg = env.from_elements("Type deduction test successful.")
    msg.output()
    env.execute()
    env.from_elements("dummy").map_partition(Verify(msg), STRING).output()
    env.execute(local=True)
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from flink.plan.Environment import get_environment
from flink.functions.MapFunction import MapFunction
from flink.functions.MapPartitionFunction import MapPartitionFunction
from flink.plan.Constants import BOOL, INT, FLOAT, STRING, BYTES
class Verify(MapPartitionFunction):
    def __init__(self, expected, name):
        super(Verify, self).__init__()
        self.expected = expected
        self.name = name

    def map_partition(self, iterator, collector):
        index = 0
        for value in iterator:
            if value != self.expected[index]:
                print(self.name + " Test failed. Expected: " + str(self.expected[index]) + " Actual: " + str(value))
                raise Exception(self.name + " failed!")
            index += 1
        collector.collect(self.name + " successful!")

class Id(MapFunction):
    def map(self, value):
        return value

if __name__ == "__main__":
    env = get_environment()

    d1 = env.from_elements(bytearray(b"hello"), bytearray(b"world"))
    d1.map(Id(), BYTES).map_partition(Verify([bytearray(b"hello"), bytearray(b"world")], "Byte"), STRING).output()

    d2 = env.from_elements(1, 2, 3, 4, 5)
    d2.map(Id(), INT).map_partition(Verify([1, 2, 3, 4, 5], "Int"), STRING).output()

    d3 = env.from_elements(True, True, False)
    d3.map(Id(), BOOL).map_partition(Verify([True, True, False], "Bool"), STRING).output()

    d4 = env.from_elements(1.4, 1.7, 12312.23)
    d4.map(Id(), FLOAT).map_partition(Verify([1.4, 1.7, 12312.23], "Float"), STRING).output()

    d5 = env.from_elements("hello", "world")
    d5.map(Id(), STRING).map_partition(Verify(["hello", "world"], "String"), STRING).output()

    env.set_degree_of_parallelism(1)
    env.execute(local=True)
\ No newline at end of file