/** * @file ArithmeticCoding.c * @author Sheng Di, Mark Thomas Nelson * @date April, 2016 * @brief Byte Toolkit * (C) 2016 by Mathematics and Computer Science (MCS), Argonne National Laboratory. * See COPYRIGHT in top-level directory. * (C) The MIT License (MIT), this code was modified from Mark's arithmetic coding code: http://www.drdobbs.com/cpp/data-compression-with-arithmetic-encodin/240169251?pgno=1 */ #include #include inline void output_bit_1(unsigned int* buf) { (*buf) = (*buf) << 1; (*buf) |= 1; } inline void output_bit_0(unsigned int* buf) { (*buf) = (*buf) << 1; //(*byte) |= 0; //actually doesn't have to set the bit to 0 } //TODO: problematic inline unsigned int output_bit_1_plus_pending(int pending_bits) { unsigned int buf = 0, pbits = pending_bits; output_bit_1(&buf); while(pbits--) output_bit_0(&buf); buf = buf << (32-(pending_bits+1)); //alignment to the left leading bit, which would be easier for the final output return buf; } inline unsigned int output_bit_0_plus_pending(int pending_bits) { unsigned int buf = 0, pbits = pending_bits; //output_bit_0(&buf); while(pbits--) output_bit_1(&buf); buf = buf << (32-(pending_bits+1)); //alignment to the left leading bit return buf; } /** * Create AriCoder for the following arithmetic encoding operation. * In this function, it will compute the real frequency of the integer codes. * @param int numOfStates (input): numOfStates is the real # states calculated to the optimization_num_of_interval code * @param int *s (input): the integer code array (i.e., type_array generated by prediction+quantization) * @param size_t length: the number of integer codes in the type_array * * */ AriCoder *createAriCoder(int numOfStates, int *s, size_t length) { AriCoder *ariCoder = (AriCoder*)malloc(sizeof(AriCoder)); memset(ariCoder, 0, sizeof(AriCoder)); ariCoder->numOfRealStates = numOfStates; ari_init(ariCoder, s, length); return ariCoder; } void freeAriCoder(AriCoder *ariCoder) { free(ariCoder->cumulative_frequency); free(ariCoder); } void ari_init(AriCoder *ariCoder, int *s, size_t length) { size_t i; //# states is in the range of integer. int index = 0; size_t *freq = (size_t *)malloc(ariCoder->numOfRealStates*sizeof(size_t)); memset(freq, 0, ariCoder->numOfRealStates*sizeof(size_t)); for(i = 0;i < length;i++) { index = s[i]; freq[index]++; } int counter = 0; size_t _sum = 0, sum = 0, freqDiv = 0; ariCoder->cumulative_frequency = (Prob *)malloc(ariCoder->numOfRealStates*sizeof(Prob)); memset(ariCoder->cumulative_frequency, 0, ariCoder->numOfRealStates*sizeof(Prob)); if(length <= MAX_INTERVALS) { for (index = 0; index < ariCoder->numOfRealStates; index++) { if (freq[index]) { sum += freq[index]; (ariCoder->cumulative_frequency[index]).low = _sum; (ariCoder->cumulative_frequency[index]).high = sum; (ariCoder->cumulative_frequency[index]).state = index; _sum = sum; counter++; } } ariCoder->numOfValidStates = counter; ariCoder->total_frequency = sum; } else { int intvSize = length%MAX_INTERVALS==0?length/MAX_INTERVALS:length/MAX_INTERVALS+1; for (index = 0; index < ariCoder->numOfRealStates; index++) { if (freq[index]) { freqDiv = freq[index]/intvSize; //control the sum of frequency to be no greater than MAX_INTERVALS if(freqDiv==0) freqDiv = 1; sum += freqDiv; (ariCoder->cumulative_frequency[index]).low = _sum; (ariCoder->cumulative_frequency[index]).high = sum; (ariCoder->cumulative_frequency[index]).state = index; _sum = sum; counter++; } } ariCoder->numOfValidStates = counter; ariCoder->total_frequency = sum; } free(freq); } /** * Convert AriCoder to bytes for storage * @param AriCoder* ariCoder (input) * @param unsigned char** out (output) * * @return outSize * */ unsigned int pad_ariCoder(AriCoder* ariCoder, unsigned char** out) { int numOfRealStates = ariCoder->numOfRealStates; int numOfValidStates = ariCoder->numOfValidStates; uint64_t total_frequency = ariCoder->total_frequency; Prob* cumulative_frequency = ariCoder->cumulative_frequency; unsigned int outSize = 0; *out = (unsigned char*)malloc(2*sizeof(int)+sizeof(uint64_t)+sizeof(Prob)*numOfRealStates); unsigned char* p = *out; intToBytes_bigEndian(p, numOfRealStates); p+=sizeof(int); intToBytes_bigEndian(p, numOfValidStates); p+=sizeof(int); int64ToBytes_bigEndian(p, total_frequency); p+=sizeof(uint64_t); size_t i = 0; if(total_frequency <= 65536) { uint16_t low, high; if(numOfRealStates<=256) { for(i=0;inumOfValidStates*5; //2*sizeof(uint16_t)+1 } else if(numOfRealStates<=65536) { for(i=0;inumOfValidStates*6; } else { for(i=0;inumOfValidStates*8; } } else if(total_frequency <=4294967296) { uint32_t low, high; if(numOfRealStates<=256) { for(i=0;inumOfValidStates*9; } else if(numOfRealStates<=65536) { for(i=0;inumOfValidStates*10; } else { for(i=0;inumOfValidStates*12; } } else { uint64_t low, high; if(numOfRealStates<=256) { for(i=0;inumOfValidStates*17; } else if(numOfRealStates<=65536) { for(i=0;inumOfValidStates*18; } else { for(i=0;inumOfValidStates*20; } } return outSize; } /** * Reconstruct AriCoder based on the bytes loaded from compressed data * @param AriCoder** ariCoder (ourput) * @param unsigned char* bytes (input) * * @return offset * */ int unpad_ariCoder(AriCoder** ariCoder, unsigned char* bytes) { int offset = 0; *ariCoder = (AriCoder*)malloc(sizeof(AriCoder)); memset(*ariCoder, 0, sizeof(AriCoder)); unsigned char *p = bytes; int numOfRealStates = (*ariCoder)->numOfRealStates = bytesToInt_bigEndian(p); p += sizeof(int); int numOfValidStates = (*ariCoder)->numOfValidStates = bytesToInt_bigEndian(p); p += sizeof(int); size_t total_frequency = (*ariCoder)->total_frequency = bytesToInt64_bigEndian(p); p += sizeof(uint64_t); (*ariCoder)->cumulative_frequency = (Prob*)malloc((*ariCoder)->numOfRealStates*sizeof(Prob)); memset((*ariCoder)->cumulative_frequency, 0, (*ariCoder)->numOfRealStates*sizeof(Prob)); size_t i = 0; unsigned char *low_p = NULL, *high_p = NULL, *state_p = NULL; int state = 0; if(total_frequency <= 65536) { if(numOfRealStates<=256) { for(i=0;icumulative_frequency[state].low = bytesToUInt16_bigEndian(low_p); (*ariCoder)->cumulative_frequency[state].high = bytesToUInt16_bigEndian(high_p); (*ariCoder)->cumulative_frequency[state].state = state; p = state_p + 1; } offset = 2*sizeof(int)+sizeof(uint64_t)+(*ariCoder)->numOfValidStates*5; //2*sizeof(uint16_t)+1 } else if(numOfRealStates<=65536) { for(i=0;icumulative_frequency[state].low = bytesToUInt16_bigEndian(low_p); (*ariCoder)->cumulative_frequency[state].high = bytesToUInt16_bigEndian(high_p); (*ariCoder)->cumulative_frequency[state].state = state; p = state_p + sizeof(uint16_t); } offset = 2*sizeof(int)+sizeof(uint64_t)+(*ariCoder)->numOfValidStates*6; } else { for(i=0;icumulative_frequency[state].low = bytesToUInt16_bigEndian(low_p); (*ariCoder)->cumulative_frequency[state].high = bytesToUInt16_bigEndian(high_p); (*ariCoder)->cumulative_frequency[state].state = state; p = state_p + sizeof(uint32_t); } offset = 2*sizeof(int)+sizeof(uint64_t)+(*ariCoder)->numOfValidStates*8; } } else if(total_frequency <=4294967296) { if(numOfRealStates<=256) { for(i=0;icumulative_frequency[state].low = bytesToUInt32_bigEndian(low_p); (*ariCoder)->cumulative_frequency[state].high = bytesToUInt32_bigEndian(high_p); (*ariCoder)->cumulative_frequency[state].state = state; p = state_p + 1; } offset = 2*sizeof(int)+sizeof(uint64_t)+(*ariCoder)->numOfValidStates*9; } else if(numOfRealStates<=65536) { for(i=0;icumulative_frequency[state].low = bytesToUInt32_bigEndian(low_p); (*ariCoder)->cumulative_frequency[state].high = bytesToUInt32_bigEndian(high_p); (*ariCoder)->cumulative_frequency[state].state = state; p = state_p + sizeof(uint16_t); } offset = 2*sizeof(int)+sizeof(uint64_t)+(*ariCoder)->numOfValidStates*10; } else { for(i=0;icumulative_frequency[state].low = bytesToUInt32_bigEndian(low_p); (*ariCoder)->cumulative_frequency[state].high = bytesToUInt32_bigEndian(high_p); (*ariCoder)->cumulative_frequency[state].state = state; p = state_p + sizeof(uint32_t); } offset = 2*sizeof(int)+sizeof(uint64_t)+(*ariCoder)->numOfValidStates*12; } } else { if(numOfRealStates<=256) { for(i=0;icumulative_frequency[state].low = bytesToUInt64_bigEndian(low_p); (*ariCoder)->cumulative_frequency[state].high = bytesToUInt64_bigEndian(high_p); (*ariCoder)->cumulative_frequency[state].state = state; p = state_p + 1; } offset = 2*sizeof(int)+sizeof(uint64_t)+(*ariCoder)->numOfValidStates*17; } else if(numOfRealStates<=65536) { for(i=0;icumulative_frequency[state].low = bytesToUInt64_bigEndian(low_p); (*ariCoder)->cumulative_frequency[state].high = bytesToUInt64_bigEndian(high_p); (*ariCoder)->cumulative_frequency[state].state = state; p = state_p + sizeof(uint16_t); } offset = 2*sizeof(int)+sizeof(uint64_t)+(*ariCoder)->numOfValidStates*18; } else { for(i=0;icumulative_frequency[state].low = bytesToUInt64_bigEndian(low_p); (*ariCoder)->cumulative_frequency[state].high = bytesToUInt64_bigEndian(high_p); (*ariCoder)->cumulative_frequency[state].state = state; p = state_p + sizeof(uint32_t); } offset = 2*sizeof(int)+sizeof(uint64_t)+(*ariCoder)->numOfValidStates*20; } } return offset; } /** * Arithmetic Encoding * @param AriCoder *ariCoder (input) * @param int *s (input) * @param size_t length (input) * @param unsigned char *out (output) * @param size_t *outSize (output) * * */ void ari_encode(AriCoder *ariCoder, int *s, size_t length, unsigned char *out, size_t *outSize) { int pending_bits = 0; size_t low = 0; size_t high = MAX_CODE; size_t i = 0, range = 0; size_t count = ariCoder->total_frequency; int c = 0, lackBits = 0; *outSize = 0; unsigned char *outp = out; Prob *cumulative_frequency = ariCoder->cumulative_frequency; unsigned int buf = 0; for (i=0;i= ONE_HALF ) { buf = output_bit_1_plus_pending(pending_bits); put_codes_to_output(buf, pending_bits+1, &outp, &lackBits, outSize); pending_bits = 0; } else if ( low >= ONE_FOURTH && high < THREE_FOURTHS ) { pending_bits++; low -= ONE_FOURTH; high -= ONE_FOURTH; } else break; high <<= 1; high++; low <<= 1; high &= MAX_CODE; low &= MAX_CODE; } } pending_bits++; if(low < ONE_FOURTH) { buf = output_bit_0_plus_pending(pending_bits); put_codes_to_output(buf, pending_bits+1, &outp, &lackBits, outSize); } else { buf = output_bit_1_plus_pending(pending_bits); put_codes_to_output(buf, pending_bits+1, &outp, &lackBits, outSize); } } /** * Get the integer code based on Arithmetic Coding Value * @param AriCoder *ariCoder (input) * @param size_t scaled_value (input) * * @return Prob* (output) * * */ Prob* getCode(AriCoder *ariCoder, size_t scaled_value) { int numOfRealStates = ariCoder->numOfRealStates; int i = 0; Prob *p = ariCoder->cumulative_frequency; for(i=0;ihigh) break; } return p; } /** * Get one bit from the input stream of bytes * @param unsigned char* p (input): the current location to be read (byte) of the byte stream * @param int offset (input): the offset of the specified byte in the byte stream * * @return unsigned char (output) : 1 or 0 * */ inline unsigned char get_bit(unsigned char* p, int offset) { return ((*p) >> (7-offset)) & 0x01; } /** * Arithmetic Decoding algorithm * @param AriCoder *ariCoder (input): the encoder with the constructed frequency information * @param unsigned char *s (input): the compressed stream of bytes * @param size_t s_len (input): the number of bytes in the 'unsigned char *s' * @param size_t targetLength (input): the target number of elements in the type array * @param int *out (output) : the result (type array decompressed from the stream 's') * * */ void ari_decode(AriCoder *ariCoder, unsigned char *s, size_t s_len, size_t targetLength, int *out) { size_t high = MAX_CODE; size_t low = 0, i = 0; size_t range = 0, scaled_value = 0; size_t total_frequency = ariCoder->total_frequency; unsigned char *sp = s+5; unsigned int offset = 4; size_t value = (bytesToUInt64_bigEndian(s) >> 20); //alignment with the MAX_CODE size_t s_counter = sizeof(int); for(i=0;itotal_frequency - 1 ) / range; Prob *p = getCode(ariCoder, scaled_value); out[i] = p->state; //output the state to the 'out' array high = low + (range*p->high)/total_frequency -1; low = low + (range*p->low)/total_frequency; for( ; ; ) { if (high < ONE_HALF) { //do nothing, bit is a zero } else if ( low >= ONE_HALF ) { value -= ONE_HALF; //subtract one half from all three code values low -= ONE_HALF; high -= ONE_HALF; } else if ( low >= ONE_FOURTH && high < THREE_FOURTHS ) { value -= ONE_FOURTH; low -= ONE_FOURTH; high -= ONE_FOURTH; } else break; low <<= 1; high <<= 1; high++; value <<= 1; //load one bit from the input byte stream if(s_counter < s_len) { value += get_bit(sp, offset++); if(offset==8) { sp++; s_counter++; offset = 0; } } } } }