提交 4f46052f 编写于 作者: V Vitaly Baranov

Optimize the class ProtobufWriter and prepare for implementing support of nested messages.

上级 1cce1fea
......@@ -69,6 +69,12 @@ namespace ProtobufColumnMatcher
const std::vector<String> & column_names,
const google::protobuf::Descriptor * message_type);
template <typename Traits = DefaultTraits>
static std::unique_ptr<Message<Traits>> matchColumns(
const std::vector<String> & column_names,
const google::protobuf::Descriptor * message_type,
std::vector<const google::protobuf::FieldDescriptor *> & field_descriptors_without_match);
namespace details
{
void throwNoCommonColumns();
......@@ -88,7 +94,8 @@ namespace ProtobufColumnMatcher
std::unique_ptr<Message<Traits>> matchColumnsRecursive(
ColumnNameMatcher & name_matcher,
const google::protobuf::Descriptor * message_type,
const String & field_name_prefix)
const String & field_name_prefix,
std::vector<const google::protobuf::FieldDescriptor *> * field_descriptors_without_match)
{
auto message = std::make_unique<Message<Traits>>();
for (int i = 0; i != message_type->field_count(); ++i)
......@@ -98,7 +105,10 @@ namespace ProtobufColumnMatcher
|| (field_descriptor->type() == google::protobuf::FieldDescriptor::TYPE_GROUP))
{
auto nested_message = matchColumnsRecursive<Traits>(
name_matcher, field_descriptor->message_type(), field_name_prefix + field_descriptor->name() + ".");
name_matcher,
field_descriptor->message_type(),
field_name_prefix + field_descriptor->name() + ".",
field_descriptors_without_match);
if (nested_message)
{
message->fields.emplace_back();
......@@ -112,7 +122,12 @@ namespace ProtobufColumnMatcher
else
{
size_t column_index = name_matcher.findColumn(field_name_prefix + field_descriptor->name());
if (column_index != static_cast<size_t>(-1))
if (column_index == static_cast<size_t>(-1))
{
if (field_descriptors_without_match)
field_descriptors_without_match->emplace_back(field_descriptor);
}
else
{
message->fields.emplace_back();
auto & current_field = message->fields.back();
......@@ -144,16 +159,34 @@ namespace ProtobufColumnMatcher
}
template <typename Data>
static std::unique_ptr<Message<Data>> matchColumns(
static std::unique_ptr<Message<Data>> matchColumnsImpl(
const std::vector<String> & column_names,
const google::protobuf::Descriptor * message_type)
const google::protobuf::Descriptor * message_type,
std::vector<const google::protobuf::FieldDescriptor *> * field_descriptors_without_match)
{
details::ColumnNameMatcher name_matcher(column_names);
auto message = details::matchColumnsRecursive<Data>(name_matcher, message_type, "");
auto message = details::matchColumnsRecursive<Data>(name_matcher, message_type, "", field_descriptors_without_match);
if (!message)
details::throwNoCommonColumns();
return message;
}
/// Matches the columns against the fields of the message type.
/// This overload does not collect the field descriptors which have no matching column:
/// a null collector is handed to matchColumnsImpl().
template <typename Data>
static std::unique_ptr<Message<Data>> matchColumns(
    const std::vector<String> & column_names,
    const google::protobuf::Descriptor * message_type)
{
    std::vector<const google::protobuf::FieldDescriptor *> * const no_collector = nullptr;
    return matchColumnsImpl<Data>(column_names, message_type, no_collector);
}
/// Matches the columns against the fields of the message type.
/// The address of 'field_descriptors_without_match' is handed to matchColumnsImpl(),
/// so the field descriptors which have no matching column can be appended to that vector.
template <typename Data>
static std::unique_ptr<Message<Data>> matchColumns(
    const std::vector<String> & column_names,
    const google::protobuf::Descriptor * message_type,
    std::vector<const google::protobuf::FieldDescriptor *> & field_descriptors_without_match)
{
    auto * const unmatched_collector = &field_descriptors_without_match;
    return matchColumnsImpl<Data>(column_names, message_type, unmatched_collector);
}
}
}
......
......@@ -9,59 +9,23 @@
#include <Formats/BlockOutputStreamFromRowOutputStream.h>
#include <Formats/FormatSchemaInfo.h>
#include <Formats/ProtobufSchemas.h>
#include <Interpreters/Context.h>
#include <google/protobuf/descriptor.h>
namespace DB
{
namespace ErrorCodes
ProtobufRowOutputStream::ProtobufRowOutputStream(WriteBuffer & out, const Block & header, const FormatSchemaInfo & info)
: data_types(header.getDataTypes()), writer(out, ProtobufSchemas::instance().getMessageTypeForFormatSchema(info), header.getNames())
{
extern const int NOT_IMPLEMENTED;
extern const int NO_DATA_FOR_REQUIRED_PROTOBUF_FIELD;
}
ProtobufRowOutputStream::ProtobufRowOutputStream(
WriteBuffer & buffer_,
const Block & header,
const google::protobuf::Descriptor * message_type)
: writer(buffer_, message_type)
{
std::vector<const ColumnWithTypeAndName *> columns_in_write_order;
const auto & fields_in_write_order = writer.fieldsInWriteOrder();
column_indices.reserve(fields_in_write_order.size());
data_types.reserve(fields_in_write_order.size());
for (size_t i = 0; i != fields_in_write_order.size(); ++i)
{
const auto * field = fields_in_write_order[i];
size_t column_index = static_cast<size_t>(-1);
DataTypePtr data_type = nullptr;
if (header.has(field->name()))
{
column_index = header.getPositionByName(field->name());
data_type = header.getByPosition(column_index).type;
}
else if (field->is_required())
{
throw Exception(
"Output doesn't have a column named '" + field->name() + "' which is required to write the output in the protobuf format.",
ErrorCodes::NO_DATA_FOR_REQUIRED_PROTOBUF_FIELD);
}
column_indices.emplace_back(column_index);
data_types.emplace_back(data_type);
}
}
void ProtobufRowOutputStream::write(const Block & block, size_t row_num)
{
writer.newMessage();
for (size_t i = 0; i != data_types.size(); ++i)
{
if (data_types[i])
data_types[i]->serializeProtobuf(*block.getByPosition(column_indices[i]).column, row_num, writer);
writer.nextField();
}
writer.startMessage();
size_t column_index;
while (writer.writeField(column_index))
data_types[column_index]->serializeProtobuf(*block.getByPosition(column_index).column, row_num, writer);
writer.endMessage();
}
......@@ -70,9 +34,8 @@ void registerOutputFormatProtobuf(FormatFactory & factory)
factory.registerOutputFormat(
"Protobuf", [](WriteBuffer & buf, const Block & header, const Context & context, const FormatSettings &)
{
const auto * message_type = ProtobufSchemas::instance().getMessageTypeForFormatSchema(FormatSchemaInfo(context, "proto"));
return std::make_shared<BlockOutputStreamFromRowOutputStream>(
std::make_shared<ProtobufRowOutputStream>(buf, header, message_type), header);
std::make_shared<ProtobufRowOutputStream>(buf, header, FormatSchemaInfo(context, "proto")), header);
});
}
......
#pragma once
#include <Core/Block.h>
#include <DataTypes/IDataType.h>
#include <Formats/IRowOutputStream.h>
#include <Formats/FormatSettings.h>
#include <Formats/ProtobufWriter.h>
......@@ -17,6 +16,9 @@ namespace protobuf
namespace DB
{
class Block;
class FormatSchemaInfo;
/** Stream designed to serialize data in the google protobuf format.
* Each row is written as a separate message.
* These messages are delimited according to documentation
......@@ -28,18 +30,14 @@ namespace DB
class ProtobufRowOutputStream : public IRowOutputStream
{
public:
ProtobufRowOutputStream(
WriteBuffer & buffer_,
const Block & header_,
const google::protobuf::Descriptor * message_prototype_);
ProtobufRowOutputStream(WriteBuffer & out, const Block & header, const FormatSchemaInfo & info);
void write(const Block & block, size_t row_num) override;
std::string getContentType() const override { return "application/octet-stream"; }
private:
ProtobufWriter writer;
DataTypes data_types;
std::vector<size_t> column_indices;
ProtobufWriter writer;
};
}
此差异已折叠。
#pragma once
#include <common/DayNum.h>
#include <Common/UInt128.h>
#include <Core/UUID.h>
#include <Common/UInt128.h>
#include <common/DayNum.h>
#include <Common/config.h>
#if USE_PROTOBUF
#include <boost/noncopyable.hpp>
#include <Formats/ProtobufColumnMatcher.h>
#include <IO/WriteBufferFromString.h>
#include <boost/noncopyable.hpp>
#include <Common/PODArray.h>
namespace google
......@@ -32,46 +34,46 @@ using ConstAggregateDataPtr = const char *;
class ProtobufWriter : private boost::noncopyable
{
public:
ProtobufWriter(WriteBuffer & out, const google::protobuf::Descriptor * message_type);
ProtobufWriter(WriteBuffer & out, const google::protobuf::Descriptor * message_type, const std::vector<String> & column_names);
~ProtobufWriter();
/// Returns fields of the protobuf schema sorted by their numbers.
const std::vector<const google::protobuf::FieldDescriptor *> & fieldsInWriteOrder() const;
/// Should be called when we start writing a new message.
void newMessage();
/// Should be called when we start writing a new field.
/// Returns false if there is no more fields in the message type.
bool nextField();
/// Returns the current field of the message type.
/// The value returned by this function changes after calling nextField() or newMessage().
const google::protobuf::FieldDescriptor * currentField() const { return current_field; }
void writeNumber(Int8 value) { current_converter->writeInt8(value); }
void writeNumber(UInt8 value) { current_converter->writeUInt8(value); }
void writeNumber(Int16 value) { current_converter->writeInt16(value); }
void writeNumber(UInt16 value) { current_converter->writeUInt16(value); }
void writeNumber(Int32 value) { current_converter->writeInt32(value); }
void writeNumber(UInt32 value) { current_converter->writeUInt32(value); }
void writeNumber(Int64 value) { current_converter->writeInt64(value); }
void writeNumber(UInt64 value) { current_converter->writeUInt64(value); }
void writeNumber(UInt128 value) { current_converter->writeUInt128(value); }
void writeNumber(Float32 value) { current_converter->writeFloat32(value); }
void writeNumber(Float64 value) { current_converter->writeFloat64(value); }
void writeString(const StringRef & str) { current_converter->writeString(str); }
void prepareEnumMapping(const std::vector<std::pair<std::string, Int8>> & enum_values) { current_converter->prepareEnumMappingInt8(enum_values); }
void prepareEnumMapping(const std::vector<std::pair<std::string, Int16>> & enum_values) { current_converter->prepareEnumMappingInt16(enum_values); }
void writeEnum(Int8 value) { current_converter->writeEnumInt8(value); }
void writeEnum(Int16 value) { current_converter->writeEnumInt16(value); }
void writeUUID(const UUID & uuid) { current_converter->writeUUID(uuid); }
void writeDate(DayNum date) { current_converter->writeDate(date); }
void writeDateTime(time_t tm) { current_converter->writeDateTime(tm); }
void writeDecimal(Decimal32 decimal, UInt32 scale) { current_converter->writeDecimal32(decimal, scale); }
void writeDecimal(Decimal64 decimal, UInt32 scale) { current_converter->writeDecimal64(decimal, scale); }
void writeDecimal(const Decimal128 & decimal, UInt32 scale) { current_converter->writeDecimal128(decimal, scale); }
void writeAggregateFunction(const AggregateFunctionPtr & function, ConstAggregateDataPtr place) { current_converter->writeAggregateFunction(function, place); }
/// Should be called at the beginning of writing a message.
void startMessage();
/// Should be called at the end of writing a message.
void endMessage();
/// Prepares for writing values of a field.
/// Returns true and sets 'column_index' to the corresponding column's index.
/// Returns false if there are no more fields to write in the message type (call endMessage() in this case).
bool writeField(size_t & column_index);
/// Writes a value. These functions should be called one or more times after writeField().
/// NOTE: they return void, so no 'false' result is returned when there is no more room for values in the protobuf's field.
/// That situation can happen if the protobuf's field is not declared as repeated in the protobuf schema.
void writeNumber(Int8 value) { writeValue(&IConverter::writeInt8, value); }
void writeNumber(UInt8 value) { writeValue(&IConverter::writeUInt8, value); }
void writeNumber(Int16 value) { writeValue(&IConverter::writeInt16, value); }
void writeNumber(UInt16 value) { writeValue(&IConverter::writeUInt16, value); }
void writeNumber(Int32 value) { writeValue(&IConverter::writeInt32, value); }
void writeNumber(UInt32 value) { writeValue(&IConverter::writeUInt32, value); }
void writeNumber(Int64 value) { writeValue(&IConverter::writeInt64, value); }
void writeNumber(UInt64 value) { writeValue(&IConverter::writeUInt64, value); }
void writeNumber(UInt128 value) { writeValue(&IConverter::writeUInt128, value); }
void writeNumber(Float32 value) { writeValue(&IConverter::writeFloat32, value); }
void writeNumber(Float64 value) { writeValue(&IConverter::writeFloat64, value); }
void writeString(const StringRef & str) { writeValue(&IConverter::writeString, str); }
void prepareEnumMapping(const std::vector<std::pair<std::string, Int8>> & enum_values) { current_converter->prepareEnumMapping8(enum_values); }
void prepareEnumMapping(const std::vector<std::pair<std::string, Int16>> & enum_values) { current_converter->prepareEnumMapping16(enum_values); }
void writeEnum(Int8 value) { writeValue(&IConverter::writeEnum8, value); }
void writeEnum(Int16 value) { writeValue(&IConverter::writeEnum16, value); }
void writeUUID(const UUID & uuid) { writeValue(&IConverter::writeUUID, uuid); }
void writeDate(DayNum date) { writeValue(&IConverter::writeDate, date); }
void writeDateTime(time_t tm) { writeValue(&IConverter::writeDateTime, tm); }
void writeDecimal(Decimal32 decimal, UInt32 scale) { writeValue(&IConverter::writeDecimal32, decimal, scale); }
void writeDecimal(Decimal64 decimal, UInt32 scale) { writeValue(&IConverter::writeDecimal64, decimal, scale); }
void writeDecimal(const Decimal128 & decimal, UInt32 scale) { writeValue(&IConverter::writeDecimal128, decimal, scale); }
void writeAggregateFunction(const AggregateFunctionPtr & function, ConstAggregateDataPtr place) { writeValue(&IConverter::writeAggregateFunction, function, place); }
private:
class SimpleWriter
......@@ -80,66 +82,38 @@ private:
SimpleWriter(WriteBuffer & out_);
~SimpleWriter();
void newMessage();
void setCurrentField(UInt32 field_number);
UInt32 currentFieldNumber() const { return current_field_number; }
size_t numValues() const { return num_normal_values + num_packed_values; }
void writeInt32(Int32 value);
void writeUInt32(UInt32 value);
void writeSInt32(Int32 value);
void writeInt64(Int64 value);
void writeUInt64(UInt64 value);
void writeSInt64(Int64 value);
void writeFixed32(UInt32 value);
void writeSFixed32(Int32 value);
void writeFloat(float value);
void writeFixed64(UInt64 value);
void writeSFixed64(Int64 value);
void writeDouble(double value);
void writeString(const StringRef & str);
void writeInt32IfNonZero(Int32 value);
void writeUInt32IfNonZero(UInt32 value);
void writeSInt32IfNonZero(Int32 value);
void writeInt64IfNonZero(Int64 value);
void writeUInt64IfNonZero(UInt64 value);
void writeSInt64IfNonZero(Int64 value);
void writeFixed32IfNonZero(UInt32 value);
void writeSFixed32IfNonZero(Int32 value);
void writeFloatIfNonZero(float value);
void writeFixed64IfNonZero(UInt64 value);
void writeSFixed64IfNonZero(Int64 value);
void writeDoubleIfNonZero(double value);
void writeStringIfNotEmpty(const StringRef & str);
void packRepeatedInt32(Int32 value);
void packRepeatedUInt32(UInt32 value);
void packRepeatedSInt32(Int32 value);
void packRepeatedInt64(Int64 value);
void packRepeatedUInt64(UInt64 value);
void packRepeatedSInt64(Int64 value);
void packRepeatedFixed32(UInt32 value);
void packRepeatedSFixed32(Int32 value);
void packRepeatedFloat(float value);
void packRepeatedFixed64(UInt64 value);
void packRepeatedSFixed64(Int64 value);
void packRepeatedDouble(double value);
void startMessage();
void endMessage();
private:
void finishCurrentMessage();
void finishCurrentField();
void writeInt(UInt32 field_number, Int64 value);
void writeUInt(UInt32 field_number, UInt64 value);
void writeSInt(UInt32 field_number, Int64 value);
template <typename T>
void writeFixed(UInt32 field_number, T value);
void writeString(UInt32 field_number, const StringRef & str);
enum WireType : UInt32;
void writeKey(WireType wire_type, WriteBuffer & buf);
void startRepeatedPack();
void addIntToRepeatedPack(Int64 value);
void addUIntToRepeatedPack(UInt64 value);
void addSIntToRepeatedPack(Int64 value);
template <typename T>
void addFixedToRepeatedPack(T value);
void endRepeatedPack(UInt32 field_number);
private:
struct Piece
{
size_t start;
size_t end;
Piece(size_t start, size_t end) : start(start), end(end) {}
Piece() = default;
};
WriteBuffer & out;
bool were_messages = false;
WriteBufferFromOwnString message_buffer;
UInt32 current_field_number = 0;
size_t num_normal_values = 0;
size_t num_packed_values = 0;
WriteBufferFromOwnString repeated_packing_buffer;
PODArray<UInt8> buffer;
std::vector<Piece> pieces;
size_t current_piece_start;
size_t num_bytes_skipped;
};
class IConverter
......@@ -158,10 +132,10 @@ private:
virtual void writeUInt128(const UInt128 &) = 0;
virtual void writeFloat32(Float32) = 0;
virtual void writeFloat64(Float64) = 0;
virtual void prepareEnumMappingInt8(const std::vector<std::pair<std::string, Int8>> &) = 0;
virtual void prepareEnumMappingInt16(const std::vector<std::pair<std::string, Int16>> &) = 0;
virtual void writeEnumInt8(Int8) = 0;
virtual void writeEnumInt16(Int16) = 0;
virtual void prepareEnumMapping8(const std::vector<std::pair<std::string, Int8>> &) = 0;
virtual void prepareEnumMapping16(const std::vector<std::pair<std::string, Int16>> &) = 0;
virtual void writeEnum8(Int8) = 0;
virtual void writeEnum16(Int16) = 0;
virtual void writeUUID(const UUID &) = 0;
virtual void writeDate(DayNum) = 0;
virtual void writeDateTime(time_t) = 0;
......@@ -172,23 +146,54 @@ private:
};
class ConverterBaseImpl;
template <int type_id> class ConverterImpl;
template <bool skip_null_value>
class ConverterToString;
template <typename ToType>
template <int field_type_id, typename ToType, bool skip_null_value, bool pack_repeated>
class ConverterToNumber;
template <bool skip_null_value, bool pack_repeated>
class ConverterToBool;
template <bool skip_null_value, bool pack_repeated>
class ConverterToEnum;
/// Traits type used to instantiate ProtobufColumnMatcher::Message and ProtobufColumnMatcher::Field
/// for the writer (see the 'Message' and 'Field' aliases declared next to this struct).
struct ColumnMatcherTraits
{
/// Data attached to each matched field.
struct FieldData
{
/// Converter owned by this field.
std::unique_ptr<IConverter> converter;
/// NOTE(review): the three flags below have no default member initializers.
/// Presumably they are filled in by setTraitsDataAfterMatchingColumns() from the protobuf
/// field descriptor (required / repeated / packed) - confirm against the implementation.
bool is_required;
bool is_repeatable;
bool should_pack_repeated;
};
/// Empty: this traits type attaches no extra data to messages.
struct MessageData {};
};
using Message = ProtobufColumnMatcher::Message<ColumnMatcherTraits>;
using Field = ProtobufColumnMatcher::Field<ColumnMatcherTraits>;
void setTraitsDataAfterMatchingColumns(Message * message);
template <int field_type_id>
std::unique_ptr<IConverter> createConverter(const google::protobuf::FieldDescriptor * field);
template <typename... Params>
using WriteValueFunctionPtr = void (IConverter::*)(Params...);
/// Invokes the given IConverter member function on 'current_converter',
/// forwarding 'args' to it, and then counts the written value in 'num_values'.
template <typename... Params, typename... Args>
void writeValue(WriteValueFunctionPtr<Params...> func, Args &&... args)
{
    IConverter & converter = *current_converter;
    (converter.*func)(std::forward<Args>(args)...);
    num_values += 1;
}
void enumerateFieldsInWriteOrder(const google::protobuf::Descriptor * message_type);
void createConverters();
void finishCurrentMessage();
void finishCurrentField();
void endWritingField();
SimpleWriter simple_writer;
std::vector<const google::protobuf::FieldDescriptor *> fields_in_write_order;
size_t current_field_index = -1;
const google::protobuf::FieldDescriptor * current_field = nullptr;
std::unique_ptr<Message> root_message;
std::vector<std::unique_ptr<IConverter>> converters;
bool writing_message = false;
size_t current_field_index = 0;
const Field * current_field = nullptr;
IConverter * current_converter = nullptr;
size_t num_values = 0;
};
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册