未验证 提交 e538cd71 编写于 作者: A Amit D 提交者: GitHub

Merge pull request #3486 from stweil/tfloat

Add TFloat data type for neural network
......@@ -29,6 +29,28 @@ namespace tesseract {
// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel AVX intrinsics to access the SIMD instruction set.
#if defined(FAST_FLOAT)
float DotProductAVX(const float *u, const float *v, int n) {
const unsigned quot = n / 8;
const unsigned rem = n % 8;
__m256 t0 = _mm256_setzero_ps();
for (unsigned k = 0; k < quot; k++) {
__m256 f0 = _mm256_loadu_ps(u);
__m256 f1 = _mm256_loadu_ps(v);
f0 = _mm256_mul_ps(f0, f1);
t0 = _mm256_add_ps(t0, f0);
u += 8;
v += 8;
}
alignas(32) float tmp[8];
_mm256_store_ps(tmp, t0);
float result = tmp[0] + tmp[1] + tmp[2] + tmp[3] + tmp[4] + tmp[5] + tmp[6] + tmp[7];
for (unsigned k = 0; k < rem; k++) {
result += *u++ * *v++;
}
return result;
}
#else
double DotProductAVX(const double *u, const double *v, int n) {
const unsigned quot = n / 8;
const unsigned rem = n % 8;
......@@ -57,6 +79,7 @@ double DotProductAVX(const double *u, const double *v, int n) {
}
return result;
}
#endif
} // namespace tesseract.
......
......@@ -29,6 +29,34 @@ namespace tesseract {
// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel FMA intrinsics to access the SIMD instruction set.
#if defined(FAST_FLOAT)
float DotProductFMA(const float *u, const float *v, int n) {
const unsigned quot = n / 16;
const unsigned rem = n % 16;
__m256 t0 = _mm256_setzero_ps();
__m256 t1 = _mm256_setzero_ps();
for (unsigned k = 0; k < quot; k++) {
__m256 f0 = _mm256_loadu_ps(u);
__m256 f1 = _mm256_loadu_ps(v);
t0 = _mm256_fmadd_ps(f0, f1, t0);
u += 8;
v += 8;
__m256 f2 = _mm256_loadu_ps(u);
__m256 f3 = _mm256_loadu_ps(v);
t1 = _mm256_fmadd_ps(f2, f3, t1);
u += 8;
v += 8;
}
t0 = _mm256_hadd_ps(t0, t1);
alignas(32) float tmp[8];
_mm256_store_ps(tmp, t0);
float result = tmp[0] + tmp[1] + tmp[2] + tmp[3] + tmp[4] + tmp[5] + tmp[6] + tmp[7];
for (unsigned k = 0; k < rem; k++) {
result += *u++ * *v++;
}
return result;
}
#else
double DotProductFMA(const double *u, const double *v, int n) {
const unsigned quot = n / 8;
const unsigned rem = n % 8;
......@@ -55,6 +83,7 @@ double DotProductFMA(const double *u, const double *v, int n) {
}
return result;
}
#endif
} // namespace tesseract.
......
......@@ -30,6 +30,66 @@ namespace tesseract {
// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel SSE intrinsics to access the SIMD instruction set.
#if defined(FAST_FLOAT)
float DotProductSSE(const float *u, const float *v, int n) {
int max_offset = n - 4;
int offset = 0;
// Accumulate a set of 4 sums in sum, by loading pairs of 4 values from u and
// v, and multiplying them together in parallel.
__m128 sum = _mm_setzero_ps();
if (offset <= max_offset) {
offset = 4;
// Aligned load is reputedly faster but requires 16 byte aligned input.
if ((reinterpret_cast<uintptr_t>(u) & 15) == 0 &&
(reinterpret_cast<uintptr_t>(v) & 15) == 0) {
// Use aligned load.
sum = _mm_load_ps(u);
__m128 floats2 = _mm_load_ps(v);
// Multiply.
sum = _mm_mul_ps(sum, floats2);
while (offset <= max_offset) {
__m128 floats1 = _mm_load_ps(u + offset);
floats2 = _mm_load_ps(v + offset);
floats1 = _mm_mul_ps(floats1, floats2);
sum = _mm_add_ps(sum, floats1);
offset += 4;
}
} else {
// Use unaligned load.
sum = _mm_loadu_ps(u);
__m128 floats2 = _mm_loadu_ps(v);
// Multiply.
sum = _mm_mul_ps(sum, floats2);
while (offset <= max_offset) {
__m128 floats1 = _mm_loadu_ps(u + offset);
floats2 = _mm_loadu_ps(v + offset);
floats1 = _mm_mul_ps(floats1, floats2);
sum = _mm_add_ps(sum, floats1);
offset += 4;
}
}
}
// Add the 4 sums in sum horizontally.
#if 0
alignas(32) float tmp[4];
_mm_store_ps(tmp, sum);
float result = tmp[0] + tmp[1] + tmp[2] + tmp[3];
#else
__m128 zero = _mm_setzero_ps();
// https://www.felixcloutier.com/x86/haddps
sum = _mm_hadd_ps(sum, zero);
sum = _mm_hadd_ps(sum, zero);
// Extract the low result.
float result = _mm_cvtss_f32(sum);
#endif
// Add on any left-over products.
while (offset < n) {
result += u[offset] * v[offset];
++offset;
}
return result;
}
#else
double DotProductSSE(const double *u, const double *v, int n) {
int max_offset = n - 2;
int offset = 0;
......@@ -39,7 +99,8 @@ double DotProductSSE(const double *u, const double *v, int n) {
if (offset <= max_offset) {
offset = 2;
// Aligned load is reputedly faster but requires 16 byte aligned input.
if ((reinterpret_cast<uintptr_t>(u) & 15) == 0 && (reinterpret_cast<uintptr_t>(v) & 15) == 0) {
if ((reinterpret_cast<uintptr_t>(u) & 15) == 0 &&
(reinterpret_cast<uintptr_t>(v) & 15) == 0) {
// Use aligned load.
sum = _mm_load_pd(u);
__m128d floats2 = _mm_load_pd(v);
......@@ -78,6 +139,7 @@ double DotProductSSE(const double *u, const double *v, int n) {
}
return result;
}
#endif
} // namespace tesseract.
......
......@@ -115,7 +115,7 @@ struct TESS_API IntSimdMatrix {
static const IntSimdMatrix *intSimdMatrix;
// Only available with NEON.
static const IntSimdMatrix intSimdMatrixNEON;
// Only available with AVX2 / SSE.
// Only available with AVX2 / AVX / FMA / SSE.
static const IntSimdMatrix intSimdMatrixAVX2;
static const IntSimdMatrix intSimdMatrixSSE;
};
......
......@@ -15,14 +15,13 @@
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "intsimdmatrix.h"
#if !defined(__AVX2__)
# if defined(__i686__) || defined(__x86_64__)
# error Implementation only for AVX2 capable architectures
# endif
#else
# include "intsimdmatrix.h"
# include <immintrin.h>
# include <algorithm>
# include <cstdint>
......@@ -86,6 +85,243 @@ static inline __m128i load64_to_128(const int8_t *wi_) {
return _mm_set_epi64x(0, wi[0]);
}
#if defined(FAST_FLOAT)
static inline void ExtractResults8(__m256i result, const int8_t *wi,
const float *scales, float *v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
__m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
__m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256 scale01234567 = _mm256_loadu_ps(scales);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result = _mm256_add_epi32(result, w256); // result += bias * 127
__m256 res01234567 = _mm256_cvtepi32_ps(result);
result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v, res01234567);
}
static inline void ExtractResults16(__m256i result0, __m256i result1,
const int8_t *&wi, const float *&scales,
float *&v) {
__m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
// 8x8bit vals in bottom of 128bit reg
const __m256i bias_scale =
_mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
__m256 scale01234567 = _mm256_loadu_ps(scales);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
__m256 res01234567 = _mm256_cvtepi32_ps(result0);
result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v, res01234567);
w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
scale01234567 = _mm256_loadu_ps(scales + 8);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
res01234567 = _mm256_cvtepi32_ps(result1);
result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v + 8, res01234567);
wi += 16;
scales += 16;
v += 16;
}
// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u,
int num_in, float *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Initialize all the results to 0.
__m256i result0 = _mm256_setzero_si256();
__m256i result1 = _mm256_setzero_si256();
__m256i result2 = _mm256_setzero_si256();
__m256i result3 = _mm256_setzero_si256();
__m256i result4 = _mm256_setzero_si256();
__m256i result5 = _mm256_setzero_si256();
__m256i result6 = _mm256_setzero_si256();
__m256i result7 = _mm256_setzero_si256();
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in;) {
__m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
// Inputs are processed in groups of kNumInputsPerGroup, replicated
// kNumInputGroups times.
for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 inputs) 8 times.
__m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
__m256i weights, reps;
// Mul-add, with horizontal add of the 4 inputs to each of the results.
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
}
}
ExtractResults16(result0, result1, wi, scales, v);
ExtractResults16(result2, result3, wi, scales, v);
ExtractResults16(result4, result5, wi, scales, v);
ExtractResults16(result6, result7, wi, scales, v);
}
// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u,
int num_in, float *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Initialize all the results to 0.
__m256i result0 = _mm256_setzero_si256();
__m256i result1 = _mm256_setzero_si256();
__m256i result2 = _mm256_setzero_si256();
__m256i result3 = _mm256_setzero_si256();
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in;) {
__m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
// Inputs are processed in groups of kNumInputsPerGroup, replicated
// kNumInputGroups times.
for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 inputs) 8 times.
__m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
__m256i weights, reps;
// Mul-add, with horizontal add of the 4 inputs to each of the results.
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
}
}
ExtractResults16(result0, result1, wi, scales, v);
ExtractResults16(result2, result3, wi, scales, v);
}
// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
int num_in, float *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Initialize all the results to 0.
__m256i result0 = _mm256_setzero_si256();
__m256i result1 = _mm256_setzero_si256();
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in;) {
__m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
// Inputs are processed in groups of kNumInputsPerGroup, replicated
// kNumInputGroups times.
for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 inputs) 8 times.
__m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
__m256i weights, reps;
// Mul-add, with horizontal add of the 4 inputs to each of the results.
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
}
}
ExtractResults16(result0, result1, wi, scales, v);
}
// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
int num_in, float *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Initialize all the results to 0.
__m256i result0 = _mm256_setzero_si256();
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in;) {
__m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
// Inputs are processed in groups of kNumInputsPerGroup, replicated
// kNumInputGroups times.
for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 inputs) 8 times.
__m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
__m256i weights, reps;
// Mul-add, with horizontal add of the 4 inputs to each of the results.
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
}
}
ExtractResults8(result0, wi, scales, v);
}
static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
const int8_t *u, float *v) {
const int num_out = dim1;
const int num_in = dim2 - 1;
// Each call to a partial_func_ produces group_size outputs, except the
// last one, which can produce less.
const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
int output = 0;
int w_step = (rounded_num_in + 1) * group_size;
// Run with this group size, until it would produce too much output, then
// switch to a smaller size.
for (; output + group_size <= rounded_num_out; output += group_size) {
PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
wi += w_step;
scales += group_size;
v += group_size;
}
group_size /= 2;
w_step /= 2;
if (output + group_size <= rounded_num_out) {
PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
wi += w_step;
scales += group_size;
v += group_size;
output += group_size;
}
group_size /= 2;
w_step /= 2;
if (output + group_size <= rounded_num_out) {
PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
wi += w_step;
scales += group_size;
v += group_size;
output += group_size;
}
group_size /= 2;
w_step /= 2;
if (output + group_size <= rounded_num_out) {
PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
}
}
#else
static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
double *v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
......@@ -330,6 +566,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *
PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
}
}
#endif
const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
// Function.
......@@ -341,7 +578,8 @@ const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
// Number of 8 bit inputs in the inputs register.
kNumInputsPerRegister,
// Number of inputs in each weight group.
kNumInputsPerGroup};
kNumInputsPerGroup
};
} // namespace tesseract.
......
......@@ -96,7 +96,11 @@ bool SIMDDetect::sse_available_;
static TFloat DotProductAccelerate(const TFloat* u, const TFloat* v, int n) {
TFloat total = 0;
const int stride = 1;
#if defined(FAST_FLOAT)
vDSP_dotpr(u, stride, v, stride, &total, n);
#else
vDSP_dotprD(u, stride, v, stride, &total, n);
#endif
return total;
}
#endif
......
......@@ -740,8 +740,11 @@ public:
TO_ROW_IT row_it = &row_list;
for (row_it.mark_cycle_pt(); !row_it.cycled_list(); row_it.forward()) {
auto row = row_it.data();
tprintf("Row range (%g,%g), para_c=%g, blobcount=%" PRId32 "\n", row->min_y(), row->max_y(),
row->parallel_c(), row->blob_list()->length());
tprintf("Row range (%g,%g), para_c=%g, blobcount=%" PRId32 "\n",
static_cast<double>(row->min_y()),
static_cast<double>(row->max_y()),
static_cast<double>(row->parallel_c()),
row->blob_list()->length());
}
}
......
......@@ -417,7 +417,7 @@ public:
// Accumulates the element-wise sums of squares of src into *this.
void SumSquares(const GENERIC_2D_ARRAY<T> &src, const T &decay_factor) {
T update_factor = 1.0 - decay_factor;
T update_factor = 1 - decay_factor;
int size = num_elements();
for (int i = 0; i < size; ++i) {
array_[i] = array_[i] * decay_factor + update_factor * src.array_[i] * src.array_[i];
......
......@@ -25,7 +25,11 @@ namespace tesseract {
using TDimension = int16_t;
// Floating point data type used for LSTM calculations.
#if defined(FAST_FLOAT)
using TFloat = float;
#else
using TFloat = double;
#endif
}
......
......@@ -21,7 +21,7 @@
#include "intsimdmatrix.h"
#include "simddetect.h" // for DotProduct
#include "statistc.h"
#include "tprintf.h"
#include "tprintf.h" // forTFloat
namespace tesseract {
......@@ -36,8 +36,22 @@ const int kAdamCorrectionIterations = 200000;
// Epsilon in Adam to prevent division by zero.
const TFloat kAdamEpsilon = 1e-8;
// Utility function converts an array of float to the corresponding array
// of double.
// Utility functions convert between double and float arrays.
#ifdef FAST_FLOAT
static void DoubleToFloat(const GENERIC_2D_ARRAY<double> &src, GENERIC_2D_ARRAY<float> &dst) {
const auto dim1 = src.dim1();
const auto dim2 = src.dim2();
dst.ResizeNoInit(dim1, dim2);
for (int i = 0; i < dim1; ++i) {
const auto *src_i = src[i];
auto *dst_i = dst[i];
for (int j = 0; j < dim2; ++j) {
dst_i[j] = static_cast<float>(src_i[j]);
}
}
}
#endif
static void FloatToDouble(const GENERIC_2D_ARRAY<float> &src, GENERIC_2D_ARRAY<double> &dst) {
const auto dim1 = src.dim1();
const auto dim2 = src.dim2();
......@@ -51,6 +65,29 @@ static void FloatToDouble(const GENERIC_2D_ARRAY<float> &src, GENERIC_2D_ARRAY<d
}
}
static bool DeSerialize(TFile *fp, GENERIC_2D_ARRAY<TFloat> &tfloat_array) {
#ifdef FAST_FLOAT
GENERIC_2D_ARRAY<double> double_array;
if (!double_array.DeSerialize(fp)) {
return false;
}
DoubleToFloat(double_array, tfloat_array);
return true;
#else
return tfloat_array.DeSerialize(fp);
#endif
}
static bool Serialize(TFile *fp, const GENERIC_2D_ARRAY<TFloat> &tfloat_array) {
#ifdef FAST_FLOAT
GENERIC_2D_ARRAY<double> double_array;
FloatToDouble(tfloat_array, double_array);
return double_array.Serialize(fp);
#else
return tfloat_array.Serialize(fp);
#endif
}
// Computes matrix.vector v = Wu.
// u is of size W.dim2() - add_bias_fwd and the output v is of size
// W.dim1() - skip_bias_back.
......@@ -209,29 +246,30 @@ bool WeightMatrix::Serialize(bool training, TFile *fp) const {
if (!wi_.Serialize(fp)) {
return false;
}
// The scales stored in memory have an extra factor applied to them
// to allow faster operation. We have to remove that factor here
// before writing to disc.
auto scales = scales_;
for (auto &scale : scales) {
scale *= INT8_MAX;
}
uint32_t size = scales.size();
uint32_t size = scales_.size();
if (!fp->Serialize(&size)) {
return false;
}
if (!fp->Serialize(&scales[0], size)) {
return false;
for (auto scale : scales_) {
// The scales stored in memory have an extra factor applied to them
// to allow faster operation. We have to remove that factor here
// before writing to disc.
double value = scale * INT8_MAX;
if (!fp->Serialize(&value)) {
return false;
}
}
} else {
if (!wf_.Serialize(fp)) {
if (!tesseract::Serialize(fp, wf_)) {
return false;
}
if (training && !updates_.Serialize(fp)) {
return false;
}
if (training && use_adam_ && !dw_sq_sum_.Serialize(fp)) {
return false;
if (training) {
if (!tesseract::Serialize(fp, updates_)) {
return false;
}
if (use_adam_ && !tesseract::Serialize(fp, dw_sq_sum_)) {
return false;
}
}
}
return true;
......@@ -257,6 +295,16 @@ bool WeightMatrix::DeSerialize(bool training, TFile *fp) {
if (!fp->DeSerialize(&size)) {
return false;
}
#ifdef FAST_FLOAT
scales_.reserve(size);
for (auto n = size; n > 0; n--) {
double val;
if (!fp->DeSerialize(&val)) {
return false;
}
scales_.push_back(val / INT8_MAX);
}
#else
scales_.resize(size);
if (!fp->DeSerialize(&scales_[0], size)) {
return false;
......@@ -264,22 +312,25 @@ bool WeightMatrix::DeSerialize(bool training, TFile *fp) {
for (auto &scale : scales_) {
scale /= INT8_MAX;
}
#endif
if (IntSimdMatrix::intSimdMatrix) {
int32_t rounded_num_out;
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_, rounded_num_out);
scales_.resize(rounded_num_out);
}
} else {
if (!wf_.DeSerialize(fp)) {
if (!tesseract::DeSerialize(fp, wf_)) {
return false;
}
if (training) {
InitBackward();
if (!updates_.DeSerialize(fp)) {
if (!tesseract::DeSerialize(fp, updates_)) {
return false;
}
if (use_adam_ && !dw_sq_sum_.DeSerialize(fp)) {
return false;
if (use_adam_) {
if (!tesseract::DeSerialize(fp, dw_sq_sum_)) {
return false;
}
}
}
}
......@@ -289,7 +340,11 @@ bool WeightMatrix::DeSerialize(bool training, TFile *fp) {
// As DeSerialize, but reads an old (float) format WeightMatrix for
// backward compatibility.
bool WeightMatrix::DeSerializeOld(bool training, TFile *fp) {
GENERIC_2D_ARRAY<float> float_array;
#ifdef FAST_FLOAT
// Not implemented.
ASSERT_HOST(!"not implemented");
return false;
#else
if (int_mode_) {
if (!wi_.DeSerialize(fp)) {
return false;
......@@ -303,6 +358,7 @@ bool WeightMatrix::DeSerializeOld(bool training, TFile *fp) {
scales_.push_back(old_scale);
}
} else {
GENERIC_2D_ARRAY<float> float_array;
if (!float_array.DeSerialize(fp)) {
return false;
}
......@@ -310,6 +366,7 @@ bool WeightMatrix::DeSerializeOld(bool training, TFile *fp) {
}
if (training) {
InitBackward();
GENERIC_2D_ARRAY<float> float_array;
if (!float_array.DeSerialize(fp)) {
return false;
}
......@@ -320,6 +377,7 @@ bool WeightMatrix::DeSerializeOld(bool training, TFile *fp) {
}
}
return true;
#endif
}
// Computes matrix.vector v = Wu.
......
......@@ -52,8 +52,8 @@ protected:
return v;
}
// Makes a random scales vector of the given size.
std::vector<double> RandomScales(int size) {
std::vector<double> v(size);
std::vector<TFloat> RandomScales(int size) {
std::vector<TFloat> v(size);
for (int i = 0; i < size; ++i) {
v[i] = (1.0 + random_.SignedRand(1.0)) / INT8_MAX;
}
......@@ -61,19 +61,19 @@ protected:
}
// Tests a range of sizes and compares the results against the generic version.
void ExpectEqualResults(const IntSimdMatrix &matrix) {
double total = 0.0;
TFloat total = 0.0;
for (int num_out = 1; num_out < 130; ++num_out) {
for (int num_in = 1; num_in < 130; ++num_in) {
GENERIC_2D_ARRAY<int8_t> w = InitRandom(num_out, num_in + 1);
std::vector<int8_t> u = RandomVector(num_in, matrix);
std::vector<double> scales = RandomScales(num_out);
std::vector<TFloat> scales = RandomScales(num_out);
int ro = num_out;
if (IntSimdMatrix::intSimdMatrix) {
ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
}
std::vector<double> base_result(num_out);
std::vector<TFloat> base_result(num_out);
IntSimdMatrix::MatrixDotVector(w, scales, u.data(), base_result.data());
std::vector<double> test_result(ro);
std::vector<TFloat> test_result(ro);
std::vector<int8_t> shaped_wi;
int32_t rounded_num_out;
matrix.Init(w, shaped_wi, rounded_num_out);
......@@ -91,7 +91,11 @@ protected:
}
}
// Compare sum of all results with expected value.
#ifdef FAST_FLOAT
EXPECT_FLOAT_EQ(total, 337852.16f);
#else
EXPECT_FLOAT_EQ(total, 337849.39354684710);
#endif
}
TRand random_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册