/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
module thrift.protocol.compact;

import std.array : uninitializedArray;
import std.typetuple : allSatisfy, TypeTuple;
import thrift.protocol.base;
import thrift.transport.base;
import thrift.internal.endian;

/**
 * D implementation of the Compact protocol.
 *
 * See THRIFT-110 for a protocol description. This implementation is based on
 * the C++ one.
 */
final class TCompactProtocol(Transport = TTransport) if (
  isTTransport!Transport
) : TProtocol {
  /**
   * Constructs a new instance.
   *
   * Params:
   *   trans = The transport to use.
   *   containerSizeLimit = If positive, the container size is limited to the
   *     given number of items.
   *   stringSizeLimit = If positive, the string length is limited to the
   *     given number of bytes.
   */
  this(Transport trans, int containerSizeLimit = 0, int stringSizeLimit = 0) {
    trans_ = trans;
    this.containerSizeLimit = containerSizeLimit;
    this.stringSizeLimit = stringSizeLimit;
  }

  /// The transport this protocol reads from and writes to.
  Transport transport() @property {
    return trans_;
  }

  /**
   * Clears the field id tracking (current id and the per-struct stack) and
   * any pending boolean field header or buffered boolean value.
   */
  void reset() {
    lastFieldId_ = 0;
    fieldIdStack_ = null;
    booleanField_ = TField.init;
    booleanFieldPending_ = false;
    hasBoolValue_ = false;
  }

  /**
   * If positive, limits the number of items of deserialized containers to the
   * given amount.
   *
   * This is useful to avoid allocating excessive amounts of memory when broken
   * data is received. If the limit is exceeded, a SIZE_LIMIT-type
   * TProtocolException is thrown.
   *
   * Defaults to zero (no limit).
   */
  int containerSizeLimit;

  /**
   * If positive, limits the length of deserialized strings/binary data to the
   * given number of bytes.
   *
   * This is useful to avoid allocating excessive amounts of memory when broken
   * data is received. If the limit is exceeded, a SIZE_LIMIT-type
   * TProtocolException is thrown.
   *
   * Defaults to zero (no limit).
   */
  int stringSizeLimit;

  /*
   * Writing methods.
   */

  /**
   * Writes a boolean.
   *
   * If a bool field header was deferred by writeFieldBegin(), the value is
   * folded into that header's type nibble; otherwise (e.g. container
   * elements) a standalone value byte is written.
   */
  void writeBool(bool b) {
    if (booleanFieldPending_) {
      // we haven't written the field header yet
      writeFieldBeginInternal(booleanField_,
        b ? CType.BOOLEAN_TRUE : CType.BOOLEAN_FALSE);
      booleanFieldPending_ = false;
    } else {
      // we're not part of a field, so just write the value
      writeByte(b ? CType.BOOLEAN_TRUE : CType.BOOLEAN_FALSE);
    }
  }

  void writeByte(byte b) {
    trans_.write((cast(ubyte*)&b)[0..1]);
  }

  void writeI16(short i16) {
    writeVarint32(i32ToZigzag(i16));
  }

  void writeI32(int i32) {
    writeVarint32(i32ToZigzag(i32));
  }

  void writeI64(long i64) {
    writeVarint64(i64ToZigzag(i64));
  }

  void writeDouble(double dub) {
    // The raw 8-byte bit pattern is sent in little-endian byte order.
    ulong bits = hostToLe(*cast(ulong*)(&dub));
    trans_.write((cast(ubyte*)&bits)[0 .. 8]);
  }

  void writeString(string str) {
    writeBinary(cast(ubyte[])str);
  }

  void writeBinary(ubyte[] buf) {
    // Varint length prefix, followed by the raw bytes.
    assert(buf.length <= int.max);
    writeVarint32(cast(int)buf.length);
    trans_.write(buf);
  }

  void writeMessageBegin(TMessage msg) {
    writeByte(cast(byte)PROTOCOL_ID);
    // Version in the low bits (VERSION_MASK), message type in the high bits
    // (TYPE_MASK).
    writeByte(cast(byte)((VERSION_N & VERSION_MASK) |
      ((cast(int)msg.type << TYPE_SHIFT_AMOUNT) & TYPE_MASK)));
    writeVarint32(msg.seqid);
    writeString(msg.name);
  }

  void writeMessageEnd() {}

  void writeStructBegin(TStruct tstruct) {
    // Field id deltas are relative to the current struct, so save the
    // enclosing struct's last id.
    fieldIdStack_ ~= lastFieldId_;
    lastFieldId_ = 0;
  }

  void writeStructEnd() {
    lastFieldId_ = fieldIdStack_[$ - 1];
    fieldIdStack_ = fieldIdStack_[0 .. $ - 1];
    fieldIdStack_.assumeSafeAppend();
  }

  /**
   * Begins a field. For bool fields the header is deferred until the value
   * is known (see writeBool()), since the compact protocol encodes the value
   * in the header's type nibble.
   */
  void writeFieldBegin(TField field) {
    if (field.type == TType.BOOL) {
      booleanField_.name = field.name;
      booleanField_.type = field.type;
      booleanField_.id = field.id;
      // Tracked explicitly rather than via a non-null name: fields produced
      // by readFieldBegin() carry a null name and must still get a header.
      booleanFieldPending_ = true;
    } else {
      writeFieldBeginInternal(field);
    }
  }

  void writeFieldEnd() {}

  void writeFieldStop() {
    writeByte(TType.STOP);
  }

  void writeListBegin(TList list) {
    writeCollectionBegin(list.elemType, list.size);
  }

  void writeListEnd() {}

  void writeMapBegin(TMap map) {
    if (map.size == 0) {
      // Empty maps are a single zero byte, without key/value types.
      writeByte(0);
    } else {
      assert(map.size <= int.max);
      writeVarint32(cast(int)map.size);
      writeByte(cast(byte)(toCType(map.keyType) << 4 | toCType(map.valueType)));
    }
  }

  void writeMapEnd() {}

  void writeSetBegin(TSet set) {
    writeCollectionBegin(set.elemType, set.size);
  }

  void writeSetEnd() {}


  /*
   * Reading methods.
   */

  /**
   * Reads a boolean. A value already decoded from a bool field header by
   * readFieldBegin() is returned first; otherwise a value byte is read.
   */
  bool readBool() {
    if (hasBoolValue_ == true) {
      hasBoolValue_ = false;
      return boolValue_;
    }

    return readByte() == CType.BOOLEAN_TRUE;
  }

  byte readByte() {
    ubyte[1] b = void;
    trans_.readAll(b);
    return cast(byte)b[0];
  }

  short readI16() {
    return cast(short)zigzagToI32(readVarint32());
  }

  int readI32() {
    return zigzagToI32(readVarint32());
  }

  long readI64() {
    return zigzagToI64(readVarint64());
  }

  double readDouble() {
    IntBuf!long b = void;
    trans_.readAll(b.bytes);
    b.value = leToHost(b.value);
    return *cast(double*)(&b.value);
  }

  string readString() {
    return cast(string)readBinary();
  }

  ubyte[] readBinary() {
    auto size = readVarint32();
    checkSize(size, stringSizeLimit);

    if (size == 0) {
      return null;
    }

    auto buf = uninitializedArray!(ubyte[])(size);
    trans_.readAll(buf);
    return buf;
  }

  TMessage readMessageBegin() {
    TMessage msg = void;

    auto protocolId = readByte();
    if (protocolId != cast(byte)PROTOCOL_ID) {
      throw new TProtocolException("Bad protocol identifier",
        TProtocolException.Type.BAD_VERSION);
    }

    auto versionAndType = readByte();
    auto ver = versionAndType & VERSION_MASK;
    if (ver != VERSION_N) {
      throw new TProtocolException("Bad protocol version",
        TProtocolException.Type.BAD_VERSION);
    }

    msg.type = cast(TMessageType)((versionAndType >> TYPE_SHIFT_AMOUNT) & TYPE_BITS);
    msg.seqid = readVarint32();
    msg.name = readString();

    return msg;
  }

  void readMessageEnd() {}

  TStruct readStructBegin() {
    fieldIdStack_ ~= lastFieldId_;
    lastFieldId_ = 0;
    return TStruct();
  }

  void readStructEnd() {
    lastFieldId_ = fieldIdStack_[$ - 1];
    fieldIdStack_ = fieldIdStack_[0 .. $ - 1];
    // Let the next readStructBegin() reuse the freed slot in place instead of
    // reallocating, matching writeStructEnd().
    fieldIdStack_.assumeSafeAppend();
  }

  TField readFieldBegin() {
    TField f = void;
    f.name = null;

    auto bite = readByte();
    auto type = cast(CType)(bite & 0x0f);

    if (type == CType.STOP) {
      // Struct stop byte, nothing more to do.
      f.id = 0;
      f.type = TType.STOP;
      return f;
    }

    // Mask off the 4 MSB of the type header, which could contain a field id
    // delta.
    auto modifier = cast(short)((bite & 0xf0) >> 4);
    if (modifier > 0) {
      f.id = cast(short)(lastFieldId_ + modifier);
    } else {
      // Delta encoding not used, just read the id as usual.
      f.id = readI16();
    }
    f.type = getTType(type);

    if (type == CType.BOOLEAN_TRUE || type == CType.BOOLEAN_FALSE) {
      // For boolean fields, the value is encoded in the type - keep it around
      // for the readBool() call.
      hasBoolValue_ = true;
      boolValue_ = (type == CType.BOOLEAN_TRUE ? true : false);
    }

    lastFieldId_ = f.id;
    return f;
  }

  void readFieldEnd() {}

  TList readListBegin() {
    auto sizeAndType = readByte();

    // The high nibble holds the size; 0xf means the size follows as a varint.
    auto lsize = (sizeAndType >> 4) & 0xf;
    if (lsize == 0xf) {
      lsize = readVarint32();
    }
    checkSize(lsize, containerSizeLimit);

    TList l = void;
    l.elemType = getTType(cast(CType)(sizeAndType & 0x0f));
    l.size = cast(size_t)lsize;

    return l;
  }

  void readListEnd() {}

  TMap readMapBegin() {
    TMap m = void;

    auto size = readVarint32();
    // The key/value type byte is only present for non-empty maps.
    ubyte kvType;
    if (size != 0) {
      kvType = readByte();
    }
    checkSize(size, containerSizeLimit);

    m.size = size;
    m.keyType = getTType(cast(CType)(kvType >> 4));
    m.valueType = getTType(cast(CType)(kvType & 0xf));

    return m;
  }

  void readMapEnd() {}

  TSet readSetBegin() {
    auto sizeAndType = readByte();

    // The high nibble holds the size; 0xf means the size follows as a varint.
    auto lsize = (sizeAndType >> 4) & 0xf;
    if (lsize == 0xf) {
      lsize = readVarint32();
    }
    checkSize(lsize, containerSizeLimit);

    TSet s = void;
    s.elemType = getTType(cast(CType)(sizeAndType & 0xf));
    s.size = cast(size_t)lsize;

    return s;
  }

  void readSetEnd() {}

private:
  /*
   * Writes a field header, delta-encoding the id into the high nibble when it
   * is 1-15 greater than the previous field id.
   */
  void writeFieldBeginInternal(TField field, byte typeOverride = -1) {
    // If there's a type override, use that.
    auto typeToWrite = (typeOverride == -1 ? toCType(field.type) : typeOverride);

    // check if we can use delta encoding for the field id
    if (field.id > lastFieldId_ && (field.id - lastFieldId_) <= 15) {
      // write them together
      writeByte(cast(byte)((field.id - lastFieldId_) << 4 | typeToWrite));
    } else {
      // write them separate
      writeByte(cast(byte)typeToWrite);
      writeI16(field.id);
    }

    lastFieldId_ = field.id;
  }

  /*
   * Writes a list/set header: sizes up to 14 share a byte with the element
   * type, larger sizes use the 0xf marker followed by a varint.
   */
  void writeCollectionBegin(TType elemType, size_t size) {
    if (size <= 14) {
      writeByte(cast(byte)(size << 4 | toCType(elemType)));
    } else {
      assert(size <= int.max);
      writeByte(cast(byte)(0xf0 | toCType(elemType)));
      writeVarint32(cast(int)size);
    }
  }

  /*
   * Write an i32 as a varint. Results in 1-5 bytes on the wire.
   */
  void writeVarint32(uint n) {
    ubyte[5] buf = void;
    ubyte wsize;

    while (true) {
      if ((n & ~0x7F) == 0) {
        buf[wsize++] = cast(ubyte)n;
        break;
      } else {
        buf[wsize++] = cast(ubyte)((n & 0x7F) | 0x80);
        n >>= 7;
      }
    }

    trans_.write(buf[0 .. wsize]);
  }

  /*
   * Write an i64 as a varint. Results in 1-10 bytes on the wire.
   */
  void writeVarint64(ulong n) {
    ubyte[10] buf = void;
    ubyte wsize;

    while (true) {
      if ((n & ~0x7FL) == 0) {
        buf[wsize++] = cast(ubyte)n;
        break;
      } else {
        buf[wsize++] = cast(ubyte)((n & 0x7F) | 0x80);
        n >>= 7;
      }
    }

    trans_.write(buf[0 .. wsize]);
  }

  /*
   * Convert l into a zigzag long. This allows negative numbers to be
   * represented compactly as a varint.
   */
  ulong i64ToZigzag(long l) {
    return (l << 1) ^ (l >> 63);
  }

  /*
   * Convert n into a zigzag int. This allows negative numbers to be
   * represented compactly as a varint.
   */
  uint i32ToZigzag(int n) {
    return (n << 1) ^ (n >> 31);
  }

  CType toCType(TType type) {
    final switch (type) {
      case TType.STOP:
        return CType.STOP;
      case TType.BOOL:
        return CType.BOOLEAN_TRUE;
      case TType.BYTE:
        return CType.BYTE;
      case TType.DOUBLE:
        return CType.DOUBLE;
      case TType.I16:
        return CType.I16;
      case TType.I32:
        return CType.I32;
      case TType.I64:
        return CType.I64;
      case TType.STRING:
        return CType.BINARY;
      case TType.STRUCT:
        return CType.STRUCT;
      case TType.MAP:
        return CType.MAP;
      case TType.SET:
        return CType.SET;
      case TType.LIST:
        return CType.LIST;
      case TType.VOID:
        assert(false, "Invalid type passed.");
    }
  }

  int readVarint32() {
    return cast(int)readVarint64();
  }

  /*
   * Reads a varint of at most 10 bytes. Uses the transport's borrow buffer
   * when 10 bytes are available, otherwise reads byte by byte.
   */
  long readVarint64() {
    ulong val;
    ubyte shift;
    ubyte[10] buf = void;  // 64 bits / (7 bits/byte) = 10 bytes.
    auto bufSize = buf.sizeof;
    auto borrowed = trans_.borrow(buf.ptr, bufSize);

    ubyte rsize;

    if (borrowed) {
      // Fast path.
      while (true) {
        auto bite = borrowed[rsize];
        rsize++;
        val |= cast(ulong)(bite & 0x7f) << shift;
        shift += 7;
        if (!(bite & 0x80)) {
          trans_.consume(rsize);
          return val;
        }
        // Have to check for invalid data so we don't crash.
        if (rsize == buf.sizeof) {
          // (message, type) argument order, as used in readMessageBegin().
          throw new TProtocolException("Variable-length int over 10 bytes.",
            TProtocolException.Type.INVALID_DATA);
        }
      }
    } else {
      // Slow path.
      while (true) {
        ubyte[1] bite;
        trans_.readAll(bite);
        ++rsize;

        val |= cast(ulong)(bite[0] & 0x7f) << shift;
        shift += 7;
        if (!(bite[0] & 0x80)) {
          return val;
        }

        // Might as well check for invalid data on the slow path too.
        if (rsize >= buf.sizeof) {
          throw new TProtocolException("Variable-length int over 10 bytes.",
            TProtocolException.Type.INVALID_DATA);
        }
      }
    }
  }

  /*
   * Convert from zigzag int to int.
   */
  int zigzagToI32(uint n) {
    return (n >> 1) ^ -(n & 1);
  }

  /*
   * Convert from zigzag long to long.
   */
  long zigzagToI64(ulong n) {
    return (n >> 1) ^ -(n & 1);
  }

  /*
   * Maps a compact wire type nibble to a TType.
   *
   * Throws: TProtocolException (INVALID_DATA) if the nibble is not a known
   *   compact type.
   */
  TType getTType(CType type) {
    // The nibble comes straight from the wire and can hold values that are
    // not CType members; a final switch does not handle those, so reject
    // them explicitly.
    if (type > CType.max) {
      throw new TProtocolException("Unknown compact protocol wire type.",
        TProtocolException.Type.INVALID_DATA);
    }

    final switch (type) {
      case CType.STOP:
        return TType.STOP;
      case CType.BOOLEAN_FALSE:
        return TType.BOOL;
      case CType.BOOLEAN_TRUE:
        return TType.BOOL;
      case CType.BYTE:
        return TType.BYTE;
      case CType.I16:
        return TType.I16;
      case CType.I32:
        return TType.I32;
      case CType.I64:
        return TType.I64;
      case CType.DOUBLE:
        return TType.DOUBLE;
      case CType.BINARY:
        return TType.STRING;
      case CType.LIST:
        return TType.LIST;
      case CType.SET:
        return TType.SET;
      case CType.MAP:
        return TType.MAP;
      case CType.STRUCT:
        return TType.STRUCT;
    }
  }

  /*
   * Throws NEGATIVE_SIZE for negative sizes, and SIZE_LIMIT if limit is
   * positive and exceeded.
   */
  void checkSize(int size, int limit) {
    if (size < 0) {
      throw new TProtocolException(TProtocolException.Type.NEGATIVE_SIZE);
    } else if (limit > 0 && size > limit) {
      throw new TProtocolException(TProtocolException.Type.SIZE_LIMIT);
    }
  }

  enum PROTOCOL_ID = 0x82;
  enum VERSION_N = 1;
  enum VERSION_MASK = 0b0001_1111;
  enum TYPE_MASK = 0b1110_0000;
  enum TYPE_BITS = 0b0000_0111;
  enum TYPE_SHIFT_AMOUNT = 5;

  // Probably need to implement a better stack at some point.
  short[] fieldIdStack_;
  short lastFieldId_;

  // Bool field whose header writeFieldBegin() deferred until writeBool().
  TField booleanField_;
  bool booleanFieldPending_;

  // Bool value decoded from a field header by readFieldBegin().
  bool hasBoolValue_;
  bool boolValue_;

  Transport trans_;
}

/**
 * TCompactProtocol construction helper to avoid having to explicitly specify
 * the transport type, i.e. to allow the constructor being called using IFTI
 * (see $(LINK2 http://d.puremagic.com/issues/show_bug.cgi?id=6082, D Bugzilla
 * enhancement request 6082)).
*/
TCompactProtocol!Transport tCompactProtocol(Transport)(Transport trans,
  int containerSizeLimit = 0, int stringSizeLimit = 0
) if (isTTransport!Transport) {
  return new TCompactProtocol!Transport(trans,
    containerSizeLimit, stringSizeLimit);
}

private {
  // Compact protocol wire type identifiers (stored in a 4-bit nibble).
  enum CType : ubyte {
    STOP = 0x0,
    BOOLEAN_TRUE = 0x1,
    BOOLEAN_FALSE = 0x2,
    BYTE = 0x3,
    I16 = 0x4,
    I32 = 0x5,
    I64 = 0x6,
    DOUBLE = 0x7,
    BINARY = 0x8,
    LIST = 0x9,
    SET = 0xa,
    MAP = 0xb,
    STRUCT = 0xc
  }
  static assert(CType.max <= 0xf,
    "Compact protocol wire type representation must fit into 4 bits.");
}

unittest {
  import std.exception;
  import thrift.transport.memory;

  // Check the message header format.
  auto buf = new TMemoryBuffer;
  auto compact = tCompactProtocol(buf);
  compact.writeMessageBegin(TMessage("foo", TMessageType.CALL, 0));

  auto header = new ubyte[7];
  buf.readAll(header);
  enforce(header == [
    130, // Protocol id.
    33, // Version/type byte.
    0, // Sequence id.
    3, 102, 111, 111 // Method name.
  ]);
}

unittest {
  import thrift.internal.test.protocol;
  testContainerSizeLimit!(TCompactProtocol!())();
  testStringSizeLimit!(TCompactProtocol!())();
}

unittest {
  import thrift.transport.memory;

  // The factory must forward its size limits to the protocols it creates.
  auto factory = new TCompactProtocolFactory!()(42, 23);
  auto prot = cast(TCompactProtocol!TTransport)factory.getProtocol(
    new TMemoryBuffer);
  assert(prot !is null);
  assert(prot.containerSizeLimit == 42);
  assert(prot.stringSizeLimit == 23);
}

/**
 * TProtocolFactory creating a TCompactProtocol instance for passed in
 * transports.
 *
 * The optional Transports template tuple parameter can be used to specify
 * one or more TTransport implementations to specifically instantiate
 * TCompactProtocol for. If the actual transport types encountered at
 * runtime match one of the transports in the list, a specialized protocol
 * instance is created. Otherwise, a generic TTransport version is used.
 */
class TCompactProtocolFactory(Transports...) if (
  allSatisfy!(isTTransport, Transports)
) : TProtocolFactory {
  /**
   * Params:
   *   containerSizeLimit = Passed to every created protocol; if positive,
   *     limits the number of items of deserialized containers.
   *   stringSizeLimit = Passed to every created protocol; if positive,
   *     limits the length of deserialized strings/binary data.
   */
  this(int containerSizeLimit = 0, int stringSizeLimit = 0) {
    containerSizeLimit_ = containerSizeLimit;
    stringSizeLimit_ = stringSizeLimit;
  }

  /**
   * Creates a protocol for trans, specialized for the first type in
   * Transports (then TTransport) that trans can be cast to.
   *
   * Throws: TProtocolException if trans is null.
   */
  TProtocol getProtocol(TTransport trans) const {
    foreach (Transport; TypeTuple!(Transports, TTransport)) {
      auto concreteTrans = cast(Transport)trans;
      if (concreteTrans) {
        return new TCompactProtocol!Transport(concreteTrans,
          containerSizeLimit_, stringSizeLimit_);
      }
    }
    throw new TProtocolException(
      "Passed null transport to TCompactProtocolFactory.");
  }

  int containerSizeLimit_;
  int stringSizeLimit_;
}