提交 30647a2e 编写于 作者: Z zentol

[FLINK-2432] Custom serializer support

This closes #962
上级 946e8f64
......@@ -34,6 +34,7 @@ import static org.apache.flink.python.api.streaming.Sender.TYPE_NULL;
import static org.apache.flink.python.api.streaming.Sender.TYPE_SHORT;
import static org.apache.flink.python.api.streaming.Sender.TYPE_STRING;
import static org.apache.flink.python.api.streaming.Sender.TYPE_TUPLE;
import org.apache.flink.python.api.types.CustomTypeWrapper;
import org.apache.flink.util.Collector;
/**
......@@ -192,7 +193,7 @@ public class Receiver implements Serializable {
case TYPE_NULL:
return null;
default:
throw new IllegalArgumentException("Unknown TypeID encountered: " + type);
return new CustomTypeDeserializer(type).deserialize();
}
}
......@@ -245,14 +246,29 @@ public class Receiver implements Serializable {
case TYPE_NULL:
return new NullDeserializer();
default:
throw new IllegalArgumentException("Unknown TypeID encountered: " + type);
return new CustomTypeDeserializer(type);
}
}
/**
 * Contract implemented by the type-specific deserializers of this class.
 *
 * @param <T> type of the value returned by {@link #deserialize()}
 */
private interface Deserializer<T> {
    /**
     * Produces one deserialized value.
     *
     * @return the deserialized value
     */
    T deserialize();
}
/**
 * Deserializer that does not interpret the incoming bytes: it hands them back untouched inside a
 * {@link CustomTypeWrapper}, tagged with the type ID this instance was created for.
 */
private class CustomTypeDeserializer implements Deserializer<CustomTypeWrapper> {
    private final byte type;

    /**
     * @param type type ID attached to every wrapper produced by this deserializer
     */
    public CustomTypeDeserializer(byte type) {
        this.type = type;
    }

    /**
     * Reads an int length from {@code fileBuffer}, followed by that many bytes.
     *
     * @return wrapper holding the type ID and the bytes that were read
     */
    @Override
    public CustomTypeWrapper deserialize() {
        final byte[] payload = new byte[fileBuffer.getInt()];
        fileBuffer.get(payload);
        return new CustomTypeWrapper(type, payload);
    }
}
private class BooleanDeserializer implements Deserializer<Boolean> {
......
......@@ -27,6 +27,7 @@ import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.java.tuple.Tuple;
import static org.apache.flink.python.api.PythonPlanBinder.FLINK_TMP_DATA_DIR;
import static org.apache.flink.python.api.PythonPlanBinder.MAPPED_FILE_SIZE;
import org.apache.flink.python.api.types.CustomTypeWrapper;
/**
* General-purpose class to write data to memory-mapped files.
......@@ -180,7 +181,7 @@ public class Sender implements Serializable {
}
private enum SupportedTypes {
TUPLE, BOOLEAN, BYTE, BYTES, CHARACTER, SHORT, INTEGER, LONG, FLOAT, DOUBLE, STRING, OTHER, NULL
TUPLE, BOOLEAN, BYTE, BYTES, CHARACTER, SHORT, INTEGER, LONG, FLOAT, DOUBLE, STRING, OTHER, NULL, CUSTOMTYPEWRAPPER
}
//=====Serializer===================================================================================================
......@@ -231,6 +232,9 @@ public class Sender implements Serializable {
case NULL:
fileBuffer.put(TYPE_NULL);
return new NullSerializer();
case CUSTOMTYPEWRAPPER:
fileBuffer.put(((CustomTypeWrapper) value).getType());
return new CustomTypeSerializer();
default:
throw new IllegalArgumentException("Unknown Type encountered: " + type);
}
......@@ -253,6 +257,18 @@ public class Sender implements Serializable {
public abstract void serializeInternal(T value);
}
/**
 * Serializer for {@link CustomTypeWrapper} values. The wrapper already carries serialized bytes, so they are exposed
 * through {@code buffer} as-is.
 */
private class CustomTypeSerializer extends Serializer<CustomTypeWrapper> {
    public CustomTypeSerializer() {
        // NOTE(review): the argument looks like an initial buffer capacity (compare ByteSerializer's super(1));
        // zero suffices here because serializeInternal replaces the buffer for every value - confirm.
        super(0);
    }

    /**
     * Points {@code buffer} at the wrapper's byte array and moves the position to its end.
     *
     * @param value wrapper whose data should be written
     */
    @Override
    public void serializeInternal(CustomTypeWrapper value) {
        final byte[] payload = value.getData();
        // The array is wrapped, not copied.
        buffer = ByteBuffer.wrap(payload);
        // Presumably the base class flips the buffer after this call, which requires the position to mark the
        // end of the written data - verify against Serializer.
        buffer.position(payload.length);
    }
}
private class ByteSerializer extends Serializer<Byte> {
public ByteSerializer() {
super(1);
......
/**
* Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE
* file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package org.apache.flink.python.api.types;
/**
 * Container for serialized python objects, generally assumed to be custom objects.
 *
 * <p>Pairs a one-byte type ID with the serialized bytes. The byte array is stored and returned by reference; no copy
 * is made on construction or on access.
 */
public class CustomTypeWrapper {
    private final byte typeId;
    private final byte[] payload;

    /**
     * @param typeID type ID identifying the kind of object the data represents
     * @param data serialized form of the object
     */
    public CustomTypeWrapper(byte typeID, byte[] data) {
        this.typeId = typeID;
        this.payload = data;
    }

    /**
     * @return the type ID passed to the constructor
     */
    public byte getType() {
        return this.typeId;
    }

    /**
     * @return the same array instance that was passed to the constructor
     */
    public byte[] getData() {
        return this.payload;
    }
}
......@@ -19,6 +19,7 @@ from struct import pack
import sys
from flink.connection.Constants import Types
from flink.plan.Constants import _Dummy
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
......@@ -30,15 +31,16 @@ else:
class Collector(object):
def __init__(self, con):
def __init__(self, con, env):
self._connection = con
self._serializer = None
self._env = env
def _close(self):
self._connection.send_end_signal()
def collect(self, value):
self._serializer = _get_serializer(self._connection.write, value)
self._serializer = _get_serializer(self._connection.write, value, self._env._types)
self.collect = self._collect
self.collect(value)
......@@ -46,11 +48,11 @@ class Collector(object):
self._connection.write(self._serializer.serialize(value))
def _get_serializer(write, value):
def _get_serializer(write, value, custom_types):
if isinstance(value, (list, tuple)):
write(Types.TYPE_TUPLE)
write(pack(">I", len(value)))
return TupleSerializer(write, value)
return TupleSerializer(write, value, custom_types)
elif value is None:
write(Types.TYPE_NULL)
return NullSerializer()
......@@ -70,12 +72,25 @@ def _get_serializer(write, value):
write(Types.TYPE_DOUBLE)
return FloatSerializer()
else:
for entry in custom_types:
if isinstance(value, entry[1]):
write(entry[0])
return CustomTypeSerializer(entry[2])
raise Exception("Unsupported Type encountered.")
class CustomTypeSerializer(object):
    """Adapter around a user-supplied serializer that prefixes its output with the output's length."""

    def __init__(self, serializer):
        # Object providing a serialize(value) method that returns the serialized bytes.
        self._serializer = serializer

    def serialize(self, value):
        """
        Serializes the given value with the wrapped serializer.

        :param value: object to serialize
        :return: the length of the serialized data as a big-endian signed 32-bit integer, followed by the data
        """
        payload = self._serializer.serialize(value)
        header = pack(">i", len(payload))
        return b"".join((header, payload))
class TupleSerializer(object):
def __init__(self, write, value):
self.serializer = [_get_serializer(write, field) for field in value]
def __init__(self, write, value, custom_types):
self.serializer = [_get_serializer(write, field, custom_types) for field in value]
def serialize(self, value):
bits = []
......@@ -117,8 +132,9 @@ class NullSerializer(object):
class TypedCollector(object):
def __init__(self, con):
def __init__(self, con, env):
self._connection = con
self._env = env
def collect(self, value):
if not isinstance(value, (list, tuple)):
......@@ -153,5 +169,13 @@ class TypedCollector(object):
value = bytes(value)
size = pack(">I", len(value))
self._connection.write(b"".join([Types.TYPE_BYTES, size, value]))
elif isinstance(value, _Dummy):
self._connection.write(pack(">i", 127)[3:])
self._connection.write(pack(">i", 0))
else:
for entry in self._env._types:
if isinstance(value, entry[1]):
self._connection.write(entry[0])
self._connection.write(CustomTypeSerializer(entry[2]).serialize(value))
return
raise Exception("Unsupported Type encountered.")
\ No newline at end of file
......@@ -168,21 +168,25 @@ class CoGroupIterator(object):
class Iterator(defIter.Iterator):
def __init__(self, con, group=0):
def __init__(self, con, env, group=0):
super(Iterator, self).__init__()
self._connection = con
self._init = True
self._group = group
self._deserializer = None
self._env = env
def __next__(self):
return self.next()
def _read(self, des_size):
return self._connection.read(des_size, self._group)
def next(self):
if self.has_next():
if self._deserializer is None:
self._deserializer = _get_deserializer(self._group, self._connection.read)
return self._deserializer.deserialize()
self._deserializer = _get_deserializer(self._group, self._connection.read, self._env._types)
return self._deserializer.deserialize(self._read)
else:
raise StopIteration
......@@ -207,121 +211,88 @@ class DummyIterator(Iterator):
return False
def _get_deserializer(group, read, type=None):
def _get_deserializer(group, read, custom_types, type=None):
if type is None:
type = read(1, group)
return _get_deserializer(group, read, type)
return _get_deserializer(group, read, custom_types, type)
elif type == Types.TYPE_TUPLE:
return TupleDeserializer(read, group)
return TupleDeserializer(read, group, custom_types)
elif type == Types.TYPE_BYTE:
return ByteDeserializer(read, group)
return ByteDeserializer()
elif type == Types.TYPE_BYTES:
return ByteArrayDeserializer(read, group)
return ByteArrayDeserializer()
elif type == Types.TYPE_BOOLEAN:
return BooleanDeserializer(read, group)
return BooleanDeserializer()
elif type == Types.TYPE_FLOAT:
return FloatDeserializer(read, group)
return FloatDeserializer()
elif type == Types.TYPE_DOUBLE:
return DoubleDeserializer(read, group)
return DoubleDeserializer()
elif type == Types.TYPE_INTEGER:
return IntegerDeserializer(read, group)
return IntegerDeserializer()
elif type == Types.TYPE_LONG:
return LongDeserializer(read, group)
return LongDeserializer()
elif type == Types.TYPE_STRING:
return StringDeserializer(read, group)
return StringDeserializer()
elif type == Types.TYPE_NULL:
return NullDeserializer(read, group)
return NullDeserializer()
else:
for entry in custom_types:
if type == entry[0]:
return entry[3]
raise Exception("Unable to find deserializer for type ID " + str(type))
class TupleDeserializer(object):
def __init__(self, read, group):
self.read = read
self._group = group
size = unpack(">I", self.read(4, self._group))[0]
self.deserializer = [_get_deserializer(self._group, self.read) for _ in range(size)]
def __init__(self, read, group, custom_types):
size = unpack(">I", read(4, group))[0]
self.deserializer = [_get_deserializer(group, read, custom_types) for _ in range(size)]
def deserialize(self):
return tuple([s.deserialize() for s in self.deserializer])
def deserialize(self, read):
return tuple([s.deserialize(read) for s in self.deserializer])
class ByteDeserializer(object):
def __init__(self, read, group):
self.read = read
self._group = group
def deserialize(self):
return unpack(">c", self.read(1, self._group))[0]
def deserialize(self, read):
return unpack(">c", read(1))[0]
class ByteArrayDeserializer(object):
def __init__(self, read, group):
self.read = read
self._group = group
def deserialize(self):
size = unpack(">i", self.read(4, self._group))[0]
return bytearray(self.read(size, self._group)) if size else bytearray(b"")
def deserialize(self, read):
size = unpack(">i", read(4))[0]
return bytearray(read(size)) if size else bytearray(b"")
class BooleanDeserializer(object):
def __init__(self, read, group):
self.read = read
self._group = group
def deserialize(self):
return unpack(">?", self.read(1, self._group))[0]
def deserialize(self, read):
return unpack(">?", read(1))[0]
class FloatDeserializer(object):
def __init__(self, read, group):
self.read = read
self._group = group
def deserialize(self):
return unpack(">f", self.read(4, self._group))[0]
def deserialize(self, read):
return unpack(">f", read(4))[0]
class DoubleDeserializer(object):
def __init__(self, read, group):
self.read = read
self._group = group
def deserialize(self):
return unpack(">d", self.read(8, self._group))[0]
def deserialize(self, read):
return unpack(">d", read(8))[0]
class IntegerDeserializer(object):
def __init__(self, read, group):
self.read = read
self._group = group
def deserialize(self):
return unpack(">i", self.read(4, self._group))[0]
def deserialize(self, read):
return unpack(">i", read(4))[0]
class LongDeserializer(object):
def __init__(self, read, group):
self.read = read
self._group = group
def deserialize(self):
return unpack(">q", self.read(8, self._group))[0]
def deserialize(self, read):
return unpack(">q", read(8))[0]
class StringDeserializer(object):
def __init__(self, read, group):
self.read = read
self._group = group
def deserialize(self):
length = unpack(">i", self.read(4, self._group))[0]
return self.read(length, self._group).decode("utf-8") if length else ""
def deserialize(self, read):
length = unpack(">i", read(4))[0]
return read(length).decode("utf-8") if length else ""
class NullDeserializer(object):
def __init__(self, read, group):
self.read = read
self._group = group
def deserialize(self):
return None
......@@ -25,13 +25,13 @@ class CoGroupFunction(Function.Function):
self._keys1 = None
self._keys2 = None
def _configure(self, input_file, output_file, port):
def _configure(self, input_file, output_file, port, env):
self._connection = Connection.TwinBufferingTCPMappedFileConnection(input_file, output_file, port)
self._iterator = Iterator.Iterator(self._connection, 0)
self._iterator2 = Iterator.Iterator(self._connection, 1)
self._iterator = Iterator.Iterator(self._connection, env, 0)
self._iterator2 = Iterator.Iterator(self._connection, env, 1)
self._cgiter = Iterator.CoGroupIterator(self._iterator, self._iterator2, self._keys1, self._keys2)
self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector)
self._configure_chain(Collector.Collector(self._connection))
self._configure_chain(Collector.Collector(self._connection, env))
def _run(self):
collector = self._collector
......
......@@ -32,11 +32,11 @@ class Function(object):
self.context = None
self._chain_operator = None
def _configure(self, input_file, output_file, port):
def _configure(self, input_file, output_file, port, env):
self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port)
self._iterator = Iterator.Iterator(self._connection)
self._iterator = Iterator.Iterator(self._connection, env)
self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector)
self._configure_chain(Collector.Collector(self._connection))
self._configure_chain(Collector.Collector(self._connection, env))
def _configure_chain(self, collector):
if self._chain_operator is not None:
......
......@@ -29,19 +29,19 @@ class GroupReduceFunction(Function.Function):
self._combine = False
self._values = []
def _configure(self, input_file, output_file, port):
def _configure(self, input_file, output_file, port, env):
if self._combine:
self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port)
self._iterator = Iterator.Iterator(self._connection)
self._collector = Collector.Collector(self._connection)
self._iterator = Iterator.Iterator(self._connection, env)
self._collector = Collector.Collector(self._connection, env)
self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector)
self._run = self._run_combine
else:
self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port)
self._iterator = Iterator.Iterator(self._connection)
self._iterator = Iterator.Iterator(self._connection, env)
self._group_iterator = Iterator.GroupIterator(self._iterator, self._keys)
self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector)
self._configure_chain(Collector.Collector(self._connection))
self._configure_chain(Collector.Collector(self._connection, env))
self._open()
def _open(self):
......
......@@ -27,21 +27,21 @@ class ReduceFunction(Function.Function):
self._combine = False
self._values = []
def _configure(self, input_file, output_file, port):
def _configure(self, input_file, output_file, port, env):
if self._combine:
self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port)
self._iterator = Iterator.Iterator(self._connection)
self._collector = Collector.Collector(self._connection)
self._iterator = Iterator.Iterator(self._connection, env)
self._collector = Collector.Collector(self._connection, env)
self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector)
self._run = self._run_combine
else:
self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port)
self._iterator = Iterator.Iterator(self._connection)
self._iterator = Iterator.Iterator(self._connection, env)
if self._keys is None:
self._run = self._run_allreduce
else:
self._group_iterator = Iterator.GroupIterator(self._iterator, self._keys)
self._configure_chain(Collector.Collector(self._connection))
self._configure_chain(Collector.Collector(self._connection, env))
self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector)
def _set_grouping_keys(self, keys):
......
......@@ -91,6 +91,11 @@ import sys
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
class _Dummy(object):
pass
if PY2:
BOOL = True
INT = 1
......@@ -98,9 +103,11 @@ if PY2:
FLOAT = 2.5
STRING = "type"
BYTES = bytearray(b"byte")
CUSTOM = _Dummy()
elif PY3:
BOOL = True
INT = 1
FLOAT = 2.5
STRING = "type"
BYTES = bytearray(b"byte")
CUSTOM = _Dummy()
......@@ -22,7 +22,7 @@ from flink.plan.Constants import _Fields, _Identifier
from flink.utilities import Switch
import copy
import sys
from struct import pack
def get_environment():
"""
......@@ -49,6 +49,19 @@ class Environment(object):
#specials
self._broadcast = []
self._types = []
def register_type(self, type, serializer, deserializer):
    """
    Registers the given type with this environment, allowing all operators within to
    (de-)serialize objects of the given type.

    Each registration is stored as a tuple (type ID, type, serializer, deserializer). The type ID is a
    single byte counting down from 126 with every registered type.

    :param type: class of the objects to be (de-)serialized
    :param serializer: instance of the serializer
    :param deserializer: instance of the deserializer
    """
    next_id = 126 - len(self._types)
    # Only the lowest-order byte of the big-endian 4-byte encoding is used as the ID.
    id_byte = pack(">i", next_id)[3:]
    self._types.append((id_byte, type, serializer, deserializer))
def read_csv(self, path, types, line_delimiter="\n", field_delimiter=','):
"""
Create a DataSet that represents the tuples produced by reading the given CSV file.
......@@ -127,7 +140,7 @@ class Environment(object):
if plan_mode:
output_path = sys.stdin.readline().rstrip('\n')
self._connection = Connection.OneWayBusyBufferingMappedFileConnection(output_path)
self._collector = Collector.TypedCollector(self._connection)
self._collector = Collector.TypedCollector(self._connection, self)
self._send_plan()
self._connection._write_buffer()
else:
......@@ -146,7 +159,7 @@ class Environment(object):
operator = set[_Fields.OPERATOR]
if set[_Fields.ID] == -id:
operator = set[_Fields.COMBINEOP]
operator._configure(input_path, output_path, port)
operator._configure(input_path, output_path, port, self)
operator._go()
sys.stdout.flush()
sys.stderr.flush()
......@@ -342,4 +355,4 @@ class Environment(object):
collect(_Identifier.BROADCAST)
collect(entry[_Fields.PARENT][_Fields.ID])
collect(entry[_Fields.OTHER][_Fields.ID])
collect(entry[_Fields.NAME])
\ No newline at end of file
collect(entry[_Fields.NAME])
......@@ -25,7 +25,8 @@ from flink.functions.CrossFunction import CrossFunction
from flink.functions.JoinFunction import JoinFunction
from flink.functions.GroupReduceFunction import GroupReduceFunction
from flink.functions.CoGroupFunction import CoGroupFunction
from flink.plan.Constants import INT, STRING, FLOAT, BOOL, Order
from flink.plan.Constants import INT, STRING, FLOAT, BOOL, CUSTOM, Order
import struct
class Mapper(MapFunction):
......@@ -259,6 +260,31 @@ if __name__ == "__main__":
.co_group(d5).where(0).equal_to(2).using(CoGroup(), ((INT, FLOAT, STRING, BOOL), (FLOAT, FLOAT, INT))) \
.map_partition(Verify([((1, 0.5, "hello", True), (4.4, 4.3, 1)), ((1, 0.4, "hello", False), (4.3, 4.4, 1))], "CoGroup"), STRING).output()
#Custom Serialization
class Ext(MapPartitionFunction):
    """Map-partition function that emits the ``value`` attribute of every element it receives."""

    def map_partition(self, iterator, collector):
        for element in iterator:
            # Look up collect on every call rather than caching the bound method; a collector may
            # swap out its own collect attribute after the first call.
            collector.collect(element.value)
class MyObj(object):
    """Minimal custom type for the custom serialization test; stores its argument as ``value``."""

    def __init__(self, i):
        self.value = i
class MySerializer(object):
    """Serializer that encodes the ``value`` attribute of an object as a big-endian signed 32-bit integer."""

    def serialize(self, value):
        number = value.value
        return struct.pack(">i", number)
class MyDeserializer(object):
    """Deserializer that rebuilds a MyObj from a big-endian signed 32-bit integer."""

    def deserialize(self, read):
        # read is called with the number of bytes wanted and returns them.
        (number,) = struct.unpack(">i", read(4))
        return MyObj(number)
env.register_type(MyObj, MySerializer(), MyDeserializer())
env.from_elements(MyObj(2), MyObj(4)) \
.map(Id(), CUSTOM).map_partition(Ext(), INT) \
.map_partition(Verify([2, 4], "CustomTypeSerialization"), STRING).output()
env.set_degree_of_parallelism(1)
env.execute(local=True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册