Commit 3b8998d5 authored by huangxingbo, committed by Dian Fu

[FLINK-20642][python] Introduce InternalRow to optimize the performance of Python UDAF

This closes #14487.
Parent commit: 5ffcfcaf
......@@ -16,6 +16,7 @@
# limitations under the License.
################################################################################
# cython: language_level=3
from pyflink.fn_execution.coder_impl_fast cimport InternalRow, InternalRowKind
cdef class DistinctViewDescriptor:
cdef object input_extractor
......@@ -83,10 +84,10 @@ cdef class GroupAggFunctionBase:
cpdef void open(self, function_context)
cpdef void close(self)
cpdef list process_element(self, object input_data)
cpdef void on_timer(self, object key)
cdef bint is_retract_msg(self, object input_data)
cdef bint is_accumulate_msg(self, object input_data)
cpdef list process_element(self, InternalRow input_data)
cpdef void on_timer(self, InternalRow key)
cdef bint is_retract_msg(self, InternalRowKind row_kind)
cdef bint is_accumulate_msg(self, InternalRowKind row_kind)
cdef class GroupAggFunction(GroupAggFunctionBase):
pass
......
......@@ -24,14 +24,13 @@ from typing import List, Dict
from apache_beam.coders import PickleCoder, Coder
from pyflink.common import Row, RowKind
from pyflink.fn_execution.aggregate import DataViewSpec, ListViewSpec, MapViewSpec, \
StateDataViewStore
from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend
from pyflink.table import AggregateFunction, TableAggregateFunction
cdef object join_row(list left, list right):
return Row(*(left.__add__(right)))
cdef InternalRow join_row(list left, list right, InternalRowKind row_kind):
    """Concatenate two field lists into a single InternalRow.

    Callers pass the grouping key as ``left`` and the aggregate result
    fields as ``right``; ``row_kind`` is the change-log kind of the
    produced row (INSERT / UPDATE_BEFORE / UPDATE_AFTER / DELETE).
    """
    # Idiom fix: use the `+` operator instead of calling list.__add__
    # directly -- equivalent for two lists and more readable.
    return InternalRow(left + right, row_kind)
cdef class DistinctViewDescriptor:
def __cinit__(self, input_extractor, filter_args):
......@@ -374,14 +373,15 @@ cdef class SimpleTableAggsHandleFunction(SimpleAggsHandleFunctionBase):
distinct_view_descriptors)
cdef list emit_value(self, list current_key, bint is_retract):
cdef InternalRow result
cdef list results
udf = self._udfs[0] # type: TableAggregateFunction
results = []
for x in udf.emit_value(self._accumulators[0]):
result = join_row(current_key, x._values)
if is_retract:
result.set_row_kind(RowKind.DELETE)
result = join_row(current_key, x._values, InternalRowKind.DELETE)
else:
result.set_row_kind(RowKind.INSERT)
result = join_row(current_key, x._values, InternalRowKind.INSERT)
results.append(result)
return results
......@@ -436,7 +436,7 @@ cdef class GroupAggFunctionBase:
cpdef void close(self):
self.aggs_handle.close()
cpdef void on_timer(self, object key):
cpdef void on_timer(self, InternalRow key):
if self.state_cleaning_enabled:
self.state_backend.set_current_key(key)
accumulator_state = self.state_backend.get_value_state(
......@@ -444,13 +444,13 @@ cdef class GroupAggFunctionBase:
accumulator_state.clear()
self.aggs_handle.cleanup()
cdef bint is_retract_msg(self, object data):
return data.get_row_kind() == RowKind.UPDATE_BEFORE or data.get_row_kind() == RowKind.DELETE
cdef bint is_retract_msg(self, InternalRowKind row_kind):
return row_kind == InternalRowKind.UPDATE_BEFORE or row_kind == InternalRowKind.DELETE
cdef bint is_accumulate_msg(self, object data):
return data.get_row_kind() == RowKind.UPDATE_AFTER or data.get_row_kind() == RowKind.INSERT
cdef bint is_accumulate_msg(self, InternalRowKind row_kind):
return row_kind == InternalRowKind.UPDATE_AFTER or row_kind == InternalRowKind.INSERT
cpdef list process_element(self, object input_data):
cpdef list process_element(self, InternalRow input_data):
pass
cdef class GroupAggFunction(GroupAggFunctionBase):
......@@ -466,13 +466,15 @@ cdef class GroupAggFunction(GroupAggFunctionBase):
aggs_handle, key_selector, state_backend, state_value_coder, generate_update_before,
state_cleaning_enabled, index_of_count_star)
cpdef list process_element(self, object input_data):
cpdef list process_element(self, InternalRow input_data):
cdef list results = []
cdef bint first_row
cdef list key, pre_agg_value, new_agg_value, accumulators, input_value
cdef object retract_row, result_row
cdef InternalRow retract_row, result_row
cdef SimpleAggsHandleFunction aggs_handle
input_value = input_data._values
cdef InternalRowKind input_row_kind
input_row_kind = input_data.row_kind
input_value = input_data.values
aggs_handle = <SimpleAggsHandleFunction> self.aggs_handle
key = self.key_selector.get_key(input_value)
self.state_backend.set_current_key(key)
......@@ -481,7 +483,7 @@ cdef class GroupAggFunction(GroupAggFunctionBase):
"accumulators", self.state_value_coder)
accumulators = accumulator_state.value()
if accumulators is None:
if self.is_retract_msg(input_data):
if self.is_retract_msg(input_row_kind):
# Don't create a new accumulator for a retraction message. This might happen if the
# retraction message is the first message for the key or after a state clean up.
return
......@@ -496,7 +498,7 @@ cdef class GroupAggFunction(GroupAggFunctionBase):
pre_agg_value = aggs_handle.get_value()
# update aggregate result and set to the newRow
if self.is_accumulate_msg(input_data):
if self.is_accumulate_msg(input_row_kind):
# accumulate input
aggs_handle.accumulate(input_value)
else:
......@@ -527,25 +529,21 @@ cdef class GroupAggFunction(GroupAggFunctionBase):
# retract previous result
if self.generate_update_before:
# prepare UPDATE_BEFORE message for previous row
retract_row = join_row(key, pre_agg_value)
retract_row.set_row_kind(RowKind.UPDATE_BEFORE)
retract_row = join_row(key, pre_agg_value, InternalRowKind.UPDATE_BEFORE)
results.append(retract_row)
# prepare UPDATE_AFTER message for new row
result_row = join_row(key, new_agg_value)
result_row.set_row_kind(RowKind.UPDATE_AFTER)
result_row = join_row(key, new_agg_value, InternalRowKind.UPDATE_AFTER)
else:
# this is the first, output new result
# prepare INSERT message for new row
result_row = join_row(key, new_agg_value)
result_row.set_row_kind(RowKind.INSERT)
result_row = join_row(key, new_agg_value, InternalRowKind.INSERT)
results.append(result_row)
else:
# we retracted the last record for this key
# sent out a delete message
if not first_row:
# prepare delete message for previous row
result_row = join_row(key, pre_agg_value)
result_row.set_row_kind(RowKind.DELETE)
result_row = join_row(key, pre_agg_value, InternalRowKind.DELETE)
results.append(result_row)
# and clear all state
accumulator_state.clear()
......@@ -566,12 +564,14 @@ cdef class GroupTableAggFunction(GroupAggFunctionBase):
aggs_handle, key_selector, state_backend, state_value_coder, generate_update_before,
state_cleaning_enabled, index_of_count_star)
cpdef list process_element(self, object input_data):
cpdef list process_element(self, InternalRow input_data):
cdef bint first_row
cdef list key, accumulators, input_value, results
cdef SimpleTableAggsHandleFunction aggs_handle
cdef InternalRowKind input_row_kind
results = []
input_value = input_data._values
input_value = input_data.values
input_row_kind = input_data.row_kind
aggs_handle = <SimpleTableAggsHandleFunction> self.aggs_handle
key = self.key_selector.get_key(input_value)
self.state_backend.set_current_key(key)
......@@ -592,7 +592,7 @@ cdef class GroupTableAggFunction(GroupAggFunctionBase):
results.append(aggs_handle.emit_value(key, True))
# update aggregate result and set to the newRow
if self.is_accumulate_msg(input_data):
if self.is_accumulate_msg(input_row_kind):
# accumulate input
aggs_handle.accumulate(input_value)
else:
......
......@@ -21,6 +21,16 @@ cimport libc.stdint
from pyflink.fn_execution.stream cimport LengthPrefixInputStream, LengthPrefixOutputStream
# Change-log kind of a row. The numeric values mirror pyflink's RowKind enum
# values (the encoder writes InternalRow.row_kind where it previously wrote
# Row.get_row_kind().value), so the kind is serialized as the same byte.
cdef enum InternalRowKind:
    INSERT = 0
    UPDATE_BEFORE = 1
    UPDATE_AFTER = 2
    DELETE = 3
cdef class InternalRow:
    # Field values of the row, one entry per column.
    cdef readonly list values
    # Change-log kind of this row (InternalRowKind enum value).
    cdef readonly InternalRowKind row_kind
cdef class BaseCoderImpl:
cpdef void encode_to_stream(self, value, LengthPrefixOutputStream output_stream)
cpdef decode_from_stream(self, LengthPrefixInputStream input_stream)
......@@ -91,11 +101,13 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
cdef float _decode_float(self) except? -1
cdef double _decode_double(self) except? -1
cdef bytes _decode_bytes(self)
cdef object _decode_field_row(self, RowCoderImpl field_coder)
cdef class AggregateFunctionRowCoderImpl(FlattenRowCoderImpl):
cdef bint _is_row_data
cdef bint _is_first_row
cdef void _encode_list_value(self, list list_value, LengthPrefixOutputStream output_stream)
cdef void _encode_internal_row(self, InternalRow row, LengthPrefixOutputStream output_stream)
cdef class TableFunctionRowCoderImpl(FlattenRowCoderImpl):
cdef char* _end_message
......@@ -210,6 +222,7 @@ cdef class MapCoderImpl(FieldCoder):
cdef class RowCoderImpl(FieldCoder):
cdef readonly list field_coders
cdef readonly list field_names
cdef readonly size_t field_count
cdef class TupleCoderImpl(FieldCoder):
cdef readonly list field_coders
......@@ -28,6 +28,22 @@ import decimal
from pyflink.table import Row
from pyflink.table.types import RowKind
cdef class InternalRow:
    """Lightweight internal row: a plain list of field values plus a row kind.

    Used on the Python UDAF hot path instead of the general-purpose
    pyflink Row to reduce per-record overhead.
    """

    def __cinit__(self, list values, InternalRowKind row_kind):
        self.values = values
        self.row_kind = row_kind

    def __eq__(self, other):
        # NOTE(review): equality compares only the field values and ignores
        # row_kind -- presumably intentional, but confirm against callers.
        if not other:
            return False
        return self.values == other.values

    def __getitem__(self, item):
        return self.values[item]

    def __iter__(self):
        # Bug fix: __iter__ must return an *iterator*. Returning the list
        # itself made iter(row) raise TypeError ("iter() returned
        # non-iterator of type 'list'"), breaking `for field in row`.
        return iter(self.values)
cdef class BaseCoderImpl:
cpdef void encode_to_stream(self, value, LengthPrefixOutputStream output_stream):
pass
......@@ -71,17 +87,48 @@ cdef class AggregateFunctionRowCoderImpl(FlattenRowCoderImpl):
cdef void _encode_list_value(self, list results, LengthPrefixOutputStream output_stream):
cdef list result
cdef InternalRow value
if self._is_first_row and results:
self._is_row_data = isinstance(results[0], Row)
self._is_row_data = isinstance(results[0], InternalRow)
self._is_first_row = False
if self._is_row_data:
for value in results:
self._encode_one_row_with_row_kind(value, output_stream, value.get_row_kind().value)
self._encode_internal_row(value, output_stream)
else:
for result in results:
for item in result:
self._encode_one_row_with_row_kind(
item, output_stream, item.get_row_kind().value)
for value in result:
self._encode_internal_row(value, output_stream)
cdef void _encode_internal_row(self, InternalRow row, LengthPrefixOutputStream output_stream):
    # Serialize one InternalRow: encode the field values together with the
    # row kind into the temporary output buffer, flush the buffer to the
    # length-prefixed stream, then reset the write position for the next row.
    self._encode_one_row_to_buffer(row.values, row.row_kind)
    output_stream.write(self._tmp_output_data, self._tmp_output_pos)
    self._tmp_output_pos = 0
cdef InternalRow _decode_field_row(self, RowCoderImpl field_coder):
    """Decode one nested row from the input buffer into an InternalRow.

    Wire layout (as read here): a bit mask whose first ROW_KIND_BIT_SIZE
    bits encode the row kind, followed by one null bit per field
    (1 = field is null), followed by the encoded non-null field values.
    """
    cdef list row_field_coders
    cdef size_t row_field_count, leading_complete_bytes_num, remaining_bits_num
    cdef bint*mask
    cdef unsigned char row_kind_value
    cdef libc.stdint.int32_t i
    cdef InternalRow row
    cdef FieldCoder row_field_coder
    row_field_coders = field_coder.field_coders
    row_field_count = field_coder.field_count
    # Mask covers the row-kind bits plus one null bit per field.
    mask = <bint*> malloc((row_field_count + ROW_KIND_BIT_SIZE) * sizeof(bint))
    leading_complete_bytes_num = (row_field_count + ROW_KIND_BIT_SIZE) // 8
    remaining_bits_num = (row_field_count + ROW_KIND_BIT_SIZE) % 8
    self._read_mask(mask, leading_complete_bytes_num, remaining_bits_num)
    # Reassemble the row kind: bit i contributes 2**i.
    row_kind_value = 0
    for i in range(ROW_KIND_BIT_SIZE):
        row_kind_value += mask[i] * 2 ** i
    # Null fields (mask bit set) become None; others are decoded with
    # their per-field coder.
    row = InternalRow([None if mask[i + ROW_KIND_BIT_SIZE] else
                       self._decode_field(
                           row_field_coders[i].coder_type(),
                           row_field_coders[i].type_name(),
                           row_field_coders[i])
                       for i in range(row_field_count)], row_kind_value)
    # The mask is only needed while building the field list.
    free(mask)
    return row
cdef class DataStreamFlatMapCoderImpl(BaseCoderImpl):
......@@ -371,13 +418,10 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
cdef object _decode_field_complex(self, TypeName field_type, FieldCoder field_coder):
cdef libc.stdint.int32_t nanoseconds, microseconds, seconds, length
cdef libc.stdint.int32_t i, row_field_count, leading_complete_bytes_num, remaining_bits_num
cdef libc.stdint.int64_t milliseconds
cdef bint*null_mask
cdef FieldCoder value_coder, key_coder
cdef TypeName value_type, key_type
cdef CoderType value_coder_type, key_coder_type
cdef list row_field_coders
if field_type == DECIMAL:
# decimal
......@@ -446,26 +490,33 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
return map_value
elif field_type == ROW:
# Row
row_field_coders = (<RowCoderImpl> field_coder).field_coders
row_field_names = (<RowCoderImpl> field_coder).field_names
row_field_count = len(row_field_coders)
mask = <bint*> malloc((row_field_count + ROW_KIND_BIT_SIZE) * sizeof(bint))
leading_complete_bytes_num = (row_field_count + ROW_KIND_BIT_SIZE) // 8
remaining_bits_num = (row_field_count + ROW_KIND_BIT_SIZE) % 8
self._read_mask(mask, leading_complete_bytes_num, remaining_bits_num)
row = Row(*[None if mask[i + ROW_KIND_BIT_SIZE] else
self._decode_field(
row_field_coders[i].coder_type(),
row_field_coders[i].type_name(),
row_field_coders[i])
for i in range(row_field_count)])
row.set_field_names(row_field_names)
row_kind_value = 0
for i in range(ROW_KIND_BIT_SIZE):
row_kind_value += mask[i] * 2 ** i
row.set_row_kind(RowKind(row_kind_value))
free(mask)
return row
return self._decode_field_row(field_coder)
cdef object _decode_field_row(self, RowCoderImpl field_coder):
    """Decode one nested row from the input buffer into a pyflink Row.

    Wire layout (as read here): a bit mask whose first ROW_KIND_BIT_SIZE
    bits encode the row kind, followed by one null bit per field
    (1 = field is null), followed by the encoded non-null field values.
    """
    cdef list row_field_coders, row_field_names
    cdef size_t row_field_count, leading_complete_bytes_num, remaining_bits_num, i
    cdef bint*mask
    cdef unsigned char row_kind_value
    row_field_coders = (<RowCoderImpl> field_coder).field_coders
    row_field_names = (<RowCoderImpl> field_coder).field_names
    row_field_count = len(row_field_coders)
    # Mask covers the row-kind bits plus one null bit per field.
    mask = <bint*> malloc((row_field_count + ROW_KIND_BIT_SIZE) * sizeof(bint))
    leading_complete_bytes_num = (row_field_count + ROW_KIND_BIT_SIZE) // 8
    remaining_bits_num = (row_field_count + ROW_KIND_BIT_SIZE) % 8
    self._read_mask(mask, leading_complete_bytes_num, remaining_bits_num)
    # Null fields (mask bit set) become None; others are decoded with
    # their per-field coder.
    row = Row(*[None if mask[i + ROW_KIND_BIT_SIZE] else
                self._decode_field(
                    row_field_coders[i].coder_type(),
                    row_field_coders[i].type_name(),
                    row_field_coders[i])
                for i in range(row_field_count)])
    row.set_field_names(row_field_names)
    # Reassemble the row kind: bit i contributes 2**i.
    row_kind_value = 0
    for i in range(ROW_KIND_BIT_SIZE):
        row_kind_value += mask[i] * 2 ** i
    row.set_row_kind(RowKind(row_kind_value))
    free(mask)
    return row
cdef unsigned char _decode_byte(self) except? -1:
self._input_pos += 1
......@@ -896,6 +947,7 @@ cdef class RowCoderImpl(FieldCoder):
def __cinit__(self, field_coders, field_names):
    self.field_coders = field_coders
    self.field_names = field_names
    # Cache the field count so per-row decoding doesn't recompute len().
    self.field_count = len(self.field_coders)
cpdef CoderType coder_type(self):
    # Rows are composite values, so this coder reports COMPLEX.
    return COMPLEX
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.