提交 c1c1e426 编写于 作者: R Ray Smith

Added new LSTM-based neural network line recognizer

上级 5d21ecfa
...@@ -16,7 +16,7 @@ endif ...@@ -16,7 +16,7 @@ endif
.PHONY: install-langs ScrollView.jar install-jars training .PHONY: install-langs ScrollView.jar install-jars training
SUBDIRS = ccutil viewer cutil opencl ccstruct dict classify wordrec textord SUBDIRS = arch ccutil viewer cutil opencl ccstruct dict classify wordrec textord lstm
if !NO_CUBE_BUILD if !NO_CUBE_BUILD
SUBDIRS += neural_networks/runtime cube SUBDIRS += neural_networks/runtime cube
endif endif
......
AM_CPPFLAGS += -DLOCALEDIR=\"$(localedir)\"\ AM_CPPFLAGS += -DLOCALEDIR=\"$(localedir)\"\
-DUSE_STD_NAMESPACE \ -DUSE_STD_NAMESPACE \
-I$(top_srcdir)/arch -I$(top_srcdir)/lstm \
-I$(top_srcdir)/ccutil -I$(top_srcdir)/ccstruct -I$(top_srcdir)/cube \ -I$(top_srcdir)/ccutil -I$(top_srcdir)/ccstruct -I$(top_srcdir)/cube \
-I$(top_srcdir)/viewer \ -I$(top_srcdir)/viewer \
-I$(top_srcdir)/textord -I$(top_srcdir)/dict \ -I$(top_srcdir)/textord -I$(top_srcdir)/dict \
...@@ -27,6 +28,9 @@ libtesseract_api_la_LIBADD = \ ...@@ -27,6 +28,9 @@ libtesseract_api_la_LIBADD = \
../wordrec/libtesseract_wordrec.la \ ../wordrec/libtesseract_wordrec.la \
../classify/libtesseract_classify.la \ ../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \ ../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \ ../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \ ../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \ ../viewer/libtesseract_viewer.la \
...@@ -57,6 +61,9 @@ libtesseract_la_LIBADD = \ ...@@ -57,6 +61,9 @@ libtesseract_la_LIBADD = \
../wordrec/libtesseract_wordrec.la \ ../wordrec/libtesseract_wordrec.la \
../classify/libtesseract_classify.la \ ../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \ ../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \ ../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \ ../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \ ../viewer/libtesseract_viewer.la \
......
...@@ -121,7 +121,6 @@ TessBaseAPI::TessBaseAPI() ...@@ -121,7 +121,6 @@ TessBaseAPI::TessBaseAPI()
block_list_(NULL), block_list_(NULL),
page_res_(NULL), page_res_(NULL),
input_file_(NULL), input_file_(NULL),
input_image_(NULL),
output_file_(NULL), output_file_(NULL),
datapath_(NULL), datapath_(NULL),
language_(NULL), language_(NULL),
...@@ -515,9 +514,7 @@ void TessBaseAPI::ClearAdaptiveClassifier() { ...@@ -515,9 +514,7 @@ void TessBaseAPI::ClearAdaptiveClassifier() {
/** /**
* Provide an image for Tesseract to recognize. Format is as * Provide an image for Tesseract to recognize. Format is as
* TesseractRect above. Does not copy the image buffer, or take * TesseractRect above. Copies the image buffer and converts to Pix.
* ownership. The source image may be destroyed after Recognize is called,
* either explicitly or implicitly via one of the Get*Text functions.
* SetImage clears all recognition results, and sets the rectangle to the * SetImage clears all recognition results, and sets the rectangle to the
* full image, so it may be followed immediately by a GetUTF8Text, and it * full image, so it may be followed immediately by a GetUTF8Text, and it
* will automatically perform recognition. * will automatically perform recognition.
...@@ -525,9 +522,11 @@ void TessBaseAPI::ClearAdaptiveClassifier() { ...@@ -525,9 +522,11 @@ void TessBaseAPI::ClearAdaptiveClassifier() {
void TessBaseAPI::SetImage(const unsigned char* imagedata, void TessBaseAPI::SetImage(const unsigned char* imagedata,
int width, int height, int width, int height,
int bytes_per_pixel, int bytes_per_line) { int bytes_per_pixel, int bytes_per_line) {
if (InternalSetImage()) if (InternalSetImage()) {
thresholder_->SetImage(imagedata, width, height, thresholder_->SetImage(imagedata, width, height,
bytes_per_pixel, bytes_per_line); bytes_per_pixel, bytes_per_line);
SetInputImage(thresholder_->GetPixRect());
}
} }
void TessBaseAPI::SetSourceResolution(int ppi) { void TessBaseAPI::SetSourceResolution(int ppi) {
...@@ -539,18 +538,17 @@ void TessBaseAPI::SetSourceResolution(int ppi) { ...@@ -539,18 +538,17 @@ void TessBaseAPI::SetSourceResolution(int ppi) {
/** /**
* Provide an image for Tesseract to recognize. As with SetImage above, * Provide an image for Tesseract to recognize. As with SetImage above,
* Tesseract doesn't take a copy or ownership or pixDestroy the image, so * Tesseract takes its own copy of the image, so it need not persist until
* it must persist until after Recognize. * after Recognize.
* Pix vs raw, which to use? * Pix vs raw, which to use?
* Use Pix where possible. A future version of Tesseract may choose to use Pix * Use Pix where possible. Tesseract uses Pix as its internal representation
* as its internal representation and discard IMAGE altogether. * and it is therefore more efficient to provide a Pix directly.
* Because of that, an implementation that sources and targets Pix may end up
* with less copies than an implementation that does not.
*/ */
void TessBaseAPI::SetImage(Pix* pix) { void TessBaseAPI::SetImage(Pix* pix) {
if (InternalSetImage()) if (InternalSetImage()) {
thresholder_->SetImage(pix); thresholder_->SetImage(pix);
SetInputImage(pix); SetInputImage(thresholder_->GetPixRect());
}
} }
/** /**
...@@ -693,8 +691,8 @@ Boxa* TessBaseAPI::GetComponentImages(PageIteratorLevel level, ...@@ -693,8 +691,8 @@ Boxa* TessBaseAPI::GetComponentImages(PageIteratorLevel level,
if (pixa != NULL) { if (pixa != NULL) {
Pix* pix = NULL; Pix* pix = NULL;
if (raw_image) { if (raw_image) {
pix = page_it->GetImage(level, raw_padding, input_image_, pix = page_it->GetImage(level, raw_padding, GetInputImage(), &left,
&left, &top); &top);
} else { } else {
pix = page_it->GetBinaryImage(level); pix = page_it->GetBinaryImage(level);
} }
...@@ -849,13 +847,17 @@ int TessBaseAPI::Recognize(ETEXT_DESC* monitor) { ...@@ -849,13 +847,17 @@ int TessBaseAPI::Recognize(ETEXT_DESC* monitor) {
} else if (tesseract_->tessedit_resegment_from_boxes) { } else if (tesseract_->tessedit_resegment_from_boxes) {
page_res_ = tesseract_->ApplyBoxes(*input_file_, false, block_list_); page_res_ = tesseract_->ApplyBoxes(*input_file_, false, block_list_);
} else { } else {
// TODO(rays) LSTM here. page_res_ = new PAGE_RES(tesseract_->AnyLSTMLang(),
page_res_ = new PAGE_RES(false,
block_list_, &tesseract_->prev_word_best_choice_); block_list_, &tesseract_->prev_word_best_choice_);
} }
if (page_res_ == NULL) { if (page_res_ == NULL) {
return -1; return -1;
} }
if (tesseract_->tessedit_train_line_recognizer) {
tesseract_->TrainLineRecognizer(*input_file_, *output_file_, block_list_);
tesseract_->CorrectClassifyWords(page_res_);
return 0;
}
if (tesseract_->tessedit_make_boxes_from_boxes) { if (tesseract_->tessedit_make_boxes_from_boxes) {
tesseract_->CorrectClassifyWords(page_res_); tesseract_->CorrectClassifyWords(page_res_);
return 0; return 0;
...@@ -938,17 +940,10 @@ int TessBaseAPI::RecognizeForChopTest(ETEXT_DESC* monitor) { ...@@ -938,17 +940,10 @@ int TessBaseAPI::RecognizeForChopTest(ETEXT_DESC* monitor) {
return 0; return 0;
} }
void TessBaseAPI::SetInputImage(Pix *pix) { // Takes ownership of the input pix.
if (input_image_) void TessBaseAPI::SetInputImage(Pix* pix) { tesseract_->set_pix_original(pix); }
pixDestroy(&input_image_);
input_image_ = NULL;
if (pix)
input_image_ = pixCopy(NULL, pix);
}
Pix* TessBaseAPI::GetInputImage() { Pix* TessBaseAPI::GetInputImage() { return tesseract_->pix_original(); }
return input_image_;
}
const char * TessBaseAPI::GetInputName() { const char * TessBaseAPI::GetInputName() {
if (input_file_) if (input_file_)
...@@ -992,8 +987,7 @@ bool TessBaseAPI::ProcessPagesFileList(FILE *flist, ...@@ -992,8 +987,7 @@ bool TessBaseAPI::ProcessPagesFileList(FILE *flist,
} }
// Begin producing output // Begin producing output
const char* kUnknownTitle = ""; if (renderer && !renderer->BeginDocument(unknown_title_)) {
if (renderer && !renderer->BeginDocument(kUnknownTitle)) {
return false; return false;
} }
...@@ -1105,7 +1099,6 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename, ...@@ -1105,7 +1099,6 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename,
const char* retry_config, const char* retry_config,
int timeout_millisec, int timeout_millisec,
TessResultRenderer* renderer) { TessResultRenderer* renderer) {
#ifndef ANDROID_BUILD
PERF_COUNT_START("ProcessPages") PERF_COUNT_START("ProcessPages")
bool stdInput = !strcmp(filename, "stdin") || !strcmp(filename, "-"); bool stdInput = !strcmp(filename, "stdin") || !strcmp(filename, "-");
if (stdInput) { if (stdInput) {
...@@ -1162,8 +1155,7 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename, ...@@ -1162,8 +1155,7 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename,
} }
// Begin the output // Begin the output
const char* kUnknownTitle = ""; if (renderer && !renderer->BeginDocument(unknown_title_)) {
if (renderer && !renderer->BeginDocument(kUnknownTitle)) {
pixDestroy(&pix); pixDestroy(&pix);
return false; return false;
} }
...@@ -1185,9 +1177,6 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename, ...@@ -1185,9 +1177,6 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename,
} }
PERF_COUNT_END PERF_COUNT_END
return true; return true;
#else
return false;
#endif
} }
bool TessBaseAPI::ProcessPage(Pix* pix, int page_index, const char* filename, bool TessBaseAPI::ProcessPage(Pix* pix, int page_index, const char* filename,
...@@ -2107,10 +2096,6 @@ void TessBaseAPI::End() { ...@@ -2107,10 +2096,6 @@ void TessBaseAPI::End() {
delete input_file_; delete input_file_;
input_file_ = NULL; input_file_ = NULL;
} }
if (input_image_ != NULL) {
pixDestroy(&input_image_);
input_image_ = NULL;
}
if (output_file_ != NULL) { if (output_file_ != NULL) {
delete output_file_; delete output_file_;
output_file_ = NULL; output_file_ = NULL;
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
#ifndef TESSERACT_API_BASEAPI_H__ #ifndef TESSERACT_API_BASEAPI_H__
#define TESSERACT_API_BASEAPI_H__ #define TESSERACT_API_BASEAPI_H__
#define TESSERACT_VERSION_STR "3.05.00dev" #define TESSERACT_VERSION_STR "4.00.00alpha"
#define TESSERACT_VERSION 0x030500 #define TESSERACT_VERSION 0x040000
#define MAKE_VERSION(major, minor, patch) (((major) << 16) | ((minor) << 8) | \ #define MAKE_VERSION(major, minor, patch) (((major) << 16) | ((minor) << 8) | \
(patch)) (patch))
...@@ -142,6 +142,7 @@ class TESS_API TessBaseAPI { ...@@ -142,6 +142,7 @@ class TESS_API TessBaseAPI {
* is stored in the PDF so we need that as well. * is stored in the PDF so we need that as well.
*/ */
const char* GetInputName(); const char* GetInputName();
// Takes ownership of the input pix.
void SetInputImage(Pix *pix); void SetInputImage(Pix *pix);
Pix* GetInputImage(); Pix* GetInputImage();
int GetSourceYResolution(); int GetSourceYResolution();
...@@ -333,9 +334,7 @@ class TESS_API TessBaseAPI { ...@@ -333,9 +334,7 @@ class TESS_API TessBaseAPI {
/** /**
* Provide an image for Tesseract to recognize. Format is as * Provide an image for Tesseract to recognize. Format is as
* TesseractRect above. Does not copy the image buffer, or take * TesseractRect above. Copies the image buffer and converts to Pix.
* ownership. The source image may be destroyed after Recognize is called,
* either explicitly or implicitly via one of the Get*Text functions.
* SetImage clears all recognition results, and sets the rectangle to the * SetImage clears all recognition results, and sets the rectangle to the
* full image, so it may be followed immediately by a GetUTF8Text, and it * full image, so it may be followed immediately by a GetUTF8Text, and it
* will automatically perform recognition. * will automatically perform recognition.
...@@ -345,13 +344,11 @@ class TESS_API TessBaseAPI { ...@@ -345,13 +344,11 @@ class TESS_API TessBaseAPI {
/** /**
* Provide an image for Tesseract to recognize. As with SetImage above, * Provide an image for Tesseract to recognize. As with SetImage above,
* Tesseract doesn't take a copy or ownership or pixDestroy the image, so * Tesseract takes its own copy of the image, so it need not persist until
* it must persist until after Recognize. * after Recognize.
* Pix vs raw, which to use? * Pix vs raw, which to use?
* Use Pix where possible. A future version of Tesseract may choose to use Pix * Use Pix where possible. Tesseract uses Pix as its internal representation
* as its internal representation and discard IMAGE altogether. * and it is therefore more efficient to provide a Pix directly.
* Because of that, an implementation that sources and targets Pix may end up
* with less copies than an implementation that does not.
*/ */
void SetImage(Pix* pix); void SetImage(Pix* pix);
...@@ -866,7 +863,6 @@ class TESS_API TessBaseAPI { ...@@ -866,7 +863,6 @@ class TESS_API TessBaseAPI {
BLOCK_LIST* block_list_; ///< The page layout. BLOCK_LIST* block_list_; ///< The page layout.
PAGE_RES* page_res_; ///< The page-level data. PAGE_RES* page_res_; ///< The page-level data.
STRING* input_file_; ///< Name used by training code. STRING* input_file_; ///< Name used by training code.
Pix* input_image_; ///< Image used for searchable PDF
STRING* output_file_; ///< Name used by debug code. STRING* output_file_; ///< Name used by debug code.
STRING* datapath_; ///< Current location of tessdata. STRING* datapath_; ///< Current location of tessdata.
STRING* language_; ///< Last initialized language. STRING* language_; ///< Last initialized language.
...@@ -902,6 +898,12 @@ class TESS_API TessBaseAPI { ...@@ -902,6 +898,12 @@ class TESS_API TessBaseAPI {
int timeout_millisec, int timeout_millisec,
TessResultRenderer* renderer, TessResultRenderer* renderer,
int tessedit_page_number); int tessedit_page_number);
// There's currently no way to pass a document title from the
// Tesseract command line, and we have multiple places that choose
// to set the title to an empty string. Using a single named
// variable will hopefully reduce confusion if the situation changes
// in the future.
const char *unknown_title_ = "";
}; // class TessBaseAPI. }; // class TessBaseAPI.
/** Escape a char string - remove &<>"' with HTML codes. */ /** Escape a char string - remove &<>"' with HTML codes. */
......
...@@ -620,7 +620,6 @@ bool TessPDFRenderer::BeginDocumentHandler() { ...@@ -620,7 +620,6 @@ bool TessPDFRenderer::BeginDocumentHandler() {
AppendPDFObject(buf); AppendPDFObject(buf);
// FONT DESCRIPTOR // FONT DESCRIPTOR
const int kCharHeight = 2; // Effect: highlights are half height
n = snprintf(buf, sizeof(buf), n = snprintf(buf, sizeof(buf),
"7 0 obj\n" "7 0 obj\n"
"<<\n" "<<\n"
...@@ -636,10 +635,10 @@ bool TessPDFRenderer::BeginDocumentHandler() { ...@@ -636,10 +635,10 @@ bool TessPDFRenderer::BeginDocumentHandler() {
" /Type /FontDescriptor\n" " /Type /FontDescriptor\n"
">>\n" ">>\n"
"endobj\n", "endobj\n",
1000 / kCharHeight, 1000,
1000 / kCharHeight, 1000,
1000 / kCharWidth, 1000 / kCharWidth,
1000 / kCharHeight, 1000,
8L // Font data 8L // Font data
); );
if (n >= sizeof(buf)) return false; if (n >= sizeof(buf)) return false;
......
...@@ -77,7 +77,7 @@ class TESS_API TessResultRenderer { ...@@ -77,7 +77,7 @@ class TESS_API TessResultRenderer {
bool EndDocument(); bool EndDocument();
const char* file_extension() const { return file_extension_; } const char* file_extension() const { return file_extension_; }
const char* title() const { return title_; } const char* title() const { return title_.c_str(); }
/** /**
* Returns the index of the last image given to AddImage * Returns the index of the last image given to AddImage
...@@ -126,7 +126,7 @@ class TESS_API TessResultRenderer { ...@@ -126,7 +126,7 @@ class TESS_API TessResultRenderer {
private: private:
const char* file_extension_; // standard extension for generated output const char* file_extension_; // standard extension for generated output
const char* title_; // title of document being renderered STRING title_; // title of document being renderered
int imagenum_; // index of last image added int imagenum_; // index of last image added
FILE* fout_; // output file pointer FILE* fout_; // output file pointer
......
# Automake rules for the arch/ subdirectory: architecture-specific SIMD
# (AVX and SSE4.1) dot-product kernels used by the LSTM line recognizer.
# Needs ccutil for host.h integer typedefs (inT8/inT32/inT64).
AM_CPPFLAGS += -I$(top_srcdir)/ccutil
AUTOMAKE_OPTIONS = subdir-objects
SUBDIRS =
AM_CXXFLAGS =
# Hide symbols by default when visibility support is enabled; TESS_EXPORTS
# marks the symbols that should remain visible.
if VISIBILITY
AM_CXXFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden
AM_CPPFLAGS += -DTESS_EXPORTS
endif
include_HEADERS = \
dotproductavx.h dotproductsse.h
noinst_HEADERS =
# Build the kernels as convenience libraries by default, or as installed
# shared libraries when building multiple libraries.
if !USING_MULTIPLELIBS
noinst_LTLIBRARIES = libtesseract_avx.la libtesseract_sse.la
else
lib_LTLIBRARIES = libtesseract_avx.la libtesseract_sse.la
libtesseract_avx_la_LDFLAGS = -version-info $(GENERIC_LIBRARY_VERSION)
libtesseract_sse_la_LDFLAGS = -version-info $(GENERIC_LIBRARY_VERSION)
endif
# Each kernel is compiled in its own library with just the instruction-set
# flag it needs, so the rest of the codebase stays portable; the runtime
# must dispatch to these only on capable CPUs.
libtesseract_avx_la_CXXFLAGS = -mavx
libtesseract_sse_la_CXXFLAGS = -msse4.1
libtesseract_avx_la_SOURCES = dotproductavx.cpp
libtesseract_sse_la_SOURCES = dotproductsse.cpp
///////////////////////////////////////////////////////////////////////
// File: dotproductavx.cpp
// Description: Architecture-specific dot-product function.
// Author: Ray Smith
// Created: Wed Jul 22 10:48:05 PDT 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#if !defined(__AVX__)
// Fallback stub for builds without AVX support. The real implementation is
// compiled only when the translation unit is built with -mavx (__AVX__).
#include "dotproductavx.h"
#include <stdio.h>
#include <stdlib.h>

namespace tesseract {

// Stub: must never be called in a non-AVX build; the caller is responsible
// for dispatching to a portable implementation instead.
double DotProductAVX(const double* u, const double* v, int n) {
  // The guard is the architecture, not the OS, so say so in the message.
  fprintf(stderr, "DotProductAVX can't be used on this architecture!\n");
  abort();
}
}  // namespace tesseract

#else  // !defined(__AVX__)
// Implementation for AVX capable archs.
#include <immintrin.h>
#include <stdint.h>
#include "dotproductavx.h"
#include "host.h"

namespace tesseract {

// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel AVX intrinsics to access the SIMD instruction set.
// u, v must point to at least n doubles; n may be any value >= 0.
double DotProductAVX(const double* u, const double* v, int n) {
  int max_offset = n - 4;
  int offset = 0;
  // Accumulate a set of 4 sums in sum, by loading pairs of 4 values from u and
  // v, and multiplying them together in parallel.
  __m256d sum = _mm256_setzero_pd();
  if (offset <= max_offset) {
    offset = 4;
    // Aligned load is reputedly faster but requires 32 byte aligned input.
    if ((reinterpret_cast<uintptr_t>(u) & 31) == 0 &&
        (reinterpret_cast<uintptr_t>(v) & 31) == 0) {
      // Use aligned load.
      __m256d floats1 = _mm256_load_pd(u);
      __m256d floats2 = _mm256_load_pd(v);
      // Multiply.
      sum = _mm256_mul_pd(floats1, floats2);
      while (offset <= max_offset) {
        floats1 = _mm256_load_pd(u + offset);
        floats2 = _mm256_load_pd(v + offset);
        offset += 4;
        __m256d product = _mm256_mul_pd(floats1, floats2);
        sum = _mm256_add_pd(sum, product);
      }
    } else {
      // Use unaligned load.
      __m256d floats1 = _mm256_loadu_pd(u);
      __m256d floats2 = _mm256_loadu_pd(v);
      // Multiply.
      sum = _mm256_mul_pd(floats1, floats2);
      while (offset <= max_offset) {
        floats1 = _mm256_loadu_pd(u + offset);
        floats2 = _mm256_loadu_pd(v + offset);
        offset += 4;
        __m256d product = _mm256_mul_pd(floats1, floats2);
        sum = _mm256_add_pd(sum, product);
      }
    }
  }
  // Add the 4 product sums together horizontally. Not so easy as with sse, as
  // there is no add across the upper/lower 128 bit boundary, so permute to
  // move the upper 128 bits to lower in another register.
  __m256d sum2 = _mm256_permute2f128_pd(sum, sum, 1);
  sum = _mm256_hadd_pd(sum, sum2);
  sum = _mm256_hadd_pd(sum, sum);
  double result;
  // _mm256_extract_f64 doesn't exist, but resist the temptation to use an sse
  // instruction, as that introduces a 70 cycle delay. All this casting is to
  // fool the intrinsics into thinking we are extracting the bottom int64.
  *(reinterpret_cast<inT64*>(&result)) =
      _mm256_extract_epi64(_mm256_castpd_si256(sum), 0);
  // Add on any left-over products for n not a multiple of 4.
  while (offset < n) {
    result += u[offset] * v[offset];
    ++offset;
  }
  return result;
}

}  // namespace tesseract.
#endif  // !defined(__AVX__)
///////////////////////////////////////////////////////////////////////
// File: dotproductavx.h
// Description: Architecture-specific dot-product function.
// Author: Ray Smith
// Created: Wed Jul 22 10:51:05 PDT 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_ARCH_DOTPRODUCTAVX_H_
#define TESSERACT_ARCH_DOTPRODUCTAVX_H_
namespace tesseract {
// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel AVX intrinsics to access the SIMD instruction set.
// u and v must each point to at least n doubles. In builds without AVX
// support this symbol resolves to a stub that aborts, so callers must only
// invoke it on AVX-capable hardware/builds.
double DotProductAVX(const double* u, const double* v, int n);
} // namespace tesseract.
#endif // TESSERACT_ARCH_DOTPRODUCTAVX_H_
///////////////////////////////////////////////////////////////////////
// File: dotproductsse.cpp
// Description: Architecture-specific dot-product function.
// Author: Ray Smith
// Created: Wed Jul 22 10:57:45 PDT 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#if !defined(__SSE4_1__)
// Fallback stubs for builds without SSE4.1 support. The real implementations
// are compiled only when the translation unit is built with -msse4.1.
#include "dotproductsse.h"
#include <stdio.h>
#include <stdlib.h>

namespace tesseract {

// Stub: must never be called in a non-SSE4.1 build; the caller is
// responsible for dispatching to a portable implementation instead.
double DotProductSSE(const double* u, const double* v, int n) {
  // The guard is the architecture, not the OS, so say so in the message.
  fprintf(stderr, "DotProductSSE can't be used on this architecture!\n");
  abort();
}
inT32 IntDotProductSSE(const inT8* u, const inT8* v, int n) {
  fprintf(stderr, "IntDotProductSSE can't be used on this architecture!\n");
  abort();
}
}  // namespace tesseract

#else  // !defined(__SSE4_1__)
// Implementation for SSE4.1 capable archs.
#include <emmintrin.h>
#include <smmintrin.h>
#include <stdint.h>
#include "dotproductsse.h"
#include "host.h"

namespace tesseract {

// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel SSE intrinsics to access the SIMD instruction set.
// u, v must point to at least n doubles; n may be any value >= 0.
double DotProductSSE(const double* u, const double* v, int n) {
  int max_offset = n - 2;
  int offset = 0;
  // Accumulate a set of 2 sums in sum, by loading pairs of 2 values from u and
  // v, and multiplying them together in parallel.
  __m128d sum = _mm_setzero_pd();
  if (offset <= max_offset) {
    offset = 2;
    // Aligned load is reputedly faster but requires 16 byte aligned input.
    if ((reinterpret_cast<uintptr_t>(u) & 15) == 0 &&
        (reinterpret_cast<uintptr_t>(v) & 15) == 0) {
      // Use aligned load.
      sum = _mm_load_pd(u);
      __m128d floats2 = _mm_load_pd(v);
      // Multiply.
      sum = _mm_mul_pd(sum, floats2);
      while (offset <= max_offset) {
        __m128d floats1 = _mm_load_pd(u + offset);
        floats2 = _mm_load_pd(v + offset);
        offset += 2;
        floats1 = _mm_mul_pd(floats1, floats2);
        sum = _mm_add_pd(sum, floats1);
      }
    } else {
      // Use unaligned load.
      sum = _mm_loadu_pd(u);
      __m128d floats2 = _mm_loadu_pd(v);
      // Multiply.
      sum = _mm_mul_pd(sum, floats2);
      while (offset <= max_offset) {
        __m128d floats1 = _mm_loadu_pd(u + offset);
        floats2 = _mm_loadu_pd(v + offset);
        offset += 2;
        floats1 = _mm_mul_pd(floats1, floats2);
        sum = _mm_add_pd(sum, floats1);
      }
    }
  }
  // Add the 2 sums in sum horizontally.
  sum = _mm_hadd_pd(sum, sum);
  // Extract the low result.
  double result = _mm_cvtsd_f64(sum);
  // Add on any left-over products for n not a multiple of 2.
  while (offset < n) {
    result += u[offset] * v[offset];
    ++offset;
  }
  return result;
}

// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel SSE intrinsics to access the SIMD instruction set.
// u, v must point to at least n signed 8-bit values; the result accumulates
// in 32 bits, so the caller must ensure n is small enough not to overflow.
inT32 IntDotProductSSE(const inT8* u, const inT8* v, int n) {
  int max_offset = n - 8;
  int offset = 0;
  // Accumulate a set of 4 32-bit sums in sum, by loading 8 pairs of 8-bit
  // values, extending to 16 bit, multiplying to make 32 bit results.
  __m128i sum = _mm_setzero_si128();
  if (offset <= max_offset) {
    offset = 8;
    __m128i packed1 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(u));
    __m128i packed2 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(v));
    sum = _mm_cvtepi8_epi16(packed1);
    packed2 = _mm_cvtepi8_epi16(packed2);
    // The magic _mm_madd_epi16 is perfect here. It multiplies 8 pairs of 16
    // bit ints to make 32 bit results, which are then horizontally added in
    // pairs to make 4 32 bit results that still fit in a 128 bit register.
    sum = _mm_madd_epi16(sum, packed2);
    while (offset <= max_offset) {
      packed1 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(u + offset));
      packed2 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(v + offset));
      offset += 8;
      packed1 = _mm_cvtepi8_epi16(packed1);
      packed2 = _mm_cvtepi8_epi16(packed2);
      packed1 = _mm_madd_epi16(packed1, packed2);
      sum = _mm_add_epi32(sum, packed1);
    }
  }
  // Sum the 4 packed 32 bit sums and extract the low result.
  sum = _mm_hadd_epi32(sum, sum);
  sum = _mm_hadd_epi32(sum, sum);
  inT32 result = _mm_cvtsi128_si32(sum);
  // Add on any left-over products for n not a multiple of 8.
  while (offset < n) {
    result += u[offset] * v[offset];
    ++offset;
  }
  return result;
}

}  // namespace tesseract.
#endif  // !defined(__SSE4_1__)
///////////////////////////////////////////////////////////////////////
// File: dotproductsse.h
// Description: Architecture-specific dot-product function.
// Author: Ray Smith
// Created: Wed Jul 22 10:57:05 PDT 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_ARCH_DOTPRODUCTSSE_H_
#define TESSERACT_ARCH_DOTPRODUCTSSE_H_
#include "host.h"
namespace tesseract {
// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel SSE intrinsics to access the SIMD instruction set.
// u and v must each point to at least n doubles. In builds without SSE4.1
// support this symbol resolves to a stub that aborts, so callers must only
// invoke it on capable hardware/builds.
double DotProductSSE(const double* u, const double* v, int n);
// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel SSE intrinsics to access the SIMD instruction set.
// u and v must each point to at least n signed 8-bit values; the 32-bit
// accumulator constrains how large n (and the element magnitudes) may be.
inT32 IntDotProductSSE(const inT8* u, const inT8* v, int n);
} // namespace tesseract.
#endif // TESSERACT_ARCH_DOTPRODUCTSSE_H_
AM_CPPFLAGS += \ AM_CPPFLAGS += \
-DUSE_STD_NAMESPACE \ -DUSE_STD_NAMESPACE \
-I$(top_srcdir)/ccutil -I$(top_srcdir)/ccstruct \ -I$(top_srcdir)/ccutil -I$(top_srcdir)/ccstruct \
-I$(top_srcdir)/arch -I$(top_srcdir)/lstm \
-I$(top_srcdir)/viewer \ -I$(top_srcdir)/viewer \
-I$(top_srcdir)/classify -I$(top_srcdir)/dict \ -I$(top_srcdir)/classify -I$(top_srcdir)/dict \
-I$(top_srcdir)/wordrec -I$(top_srcdir)/cutil \ -I$(top_srcdir)/wordrec -I$(top_srcdir)/cutil \
...@@ -33,6 +34,9 @@ libtesseract_main_la_LIBADD = \ ...@@ -33,6 +34,9 @@ libtesseract_main_la_LIBADD = \
../ccstruct/libtesseract_ccstruct.la \ ../ccstruct/libtesseract_ccstruct.la \
../viewer/libtesseract_viewer.la \ ../viewer/libtesseract_viewer.la \
../dict/libtesseract_dict.la \ ../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../classify/libtesseract_classify.la \ ../classify/libtesseract_classify.la \
../cutil/libtesseract_cutil.la \ ../cutil/libtesseract_cutil.la \
../opencl/libtesseract_opencl.la ../opencl/libtesseract_opencl.la
...@@ -44,7 +48,7 @@ endif ...@@ -44,7 +48,7 @@ endif
libtesseract_main_la_SOURCES = \ libtesseract_main_la_SOURCES = \
adaptions.cpp applybox.cpp control.cpp \ adaptions.cpp applybox.cpp control.cpp \
docqual.cpp equationdetect.cpp fixspace.cpp fixxht.cpp \ docqual.cpp equationdetect.cpp fixspace.cpp fixxht.cpp \
ltrresultiterator.cpp \ linerec.cpp ltrresultiterator.cpp \
osdetect.cpp output.cpp pageiterator.cpp pagesegmain.cpp \ osdetect.cpp output.cpp pageiterator.cpp pagesegmain.cpp \
pagewalk.cpp par_control.cpp paragraphs.cpp paramsd.cpp pgedit.cpp recogtraining.cpp \ pagewalk.cpp par_control.cpp paragraphs.cpp paramsd.cpp pgedit.cpp recogtraining.cpp \
reject.cpp resultiterator.cpp superscript.cpp \ reject.cpp resultiterator.cpp superscript.cpp \
......
...@@ -84,7 +84,12 @@ BOOL8 Tesseract::recog_interactive(PAGE_RES_IT* pr_it) { ...@@ -84,7 +84,12 @@ BOOL8 Tesseract::recog_interactive(PAGE_RES_IT* pr_it) {
WordData word_data(*pr_it); WordData word_data(*pr_it);
SetupWordPassN(2, &word_data); SetupWordPassN(2, &word_data);
classify_word_and_language(2, pr_it, &word_data); // LSTM doesn't run on pass2, but we want to run pass2 for tesseract.
if (lstm_recognizer_ == NULL) {
classify_word_and_language(2, pr_it, &word_data);
} else {
classify_word_and_language(1, pr_it, &word_data);
}
if (tessedit_debug_quality_metrics) { if (tessedit_debug_quality_metrics) {
WERD_RES* word_res = pr_it->word(); WERD_RES* word_res = pr_it->word();
word_char_quality(word_res, pr_it->row()->row, &char_qual, &good_char_qual); word_char_quality(word_res, pr_it->row()->row, &char_qual, &good_char_qual);
...@@ -218,16 +223,14 @@ bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC* monitor, ...@@ -218,16 +223,14 @@ bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC* monitor,
if (pass_n == 1) { if (pass_n == 1) {
monitor->progress = 70 * w / words->size(); monitor->progress = 70 * w / words->size();
if (monitor->progress_callback != NULL) { if (monitor->progress_callback != NULL) {
TBOX box = pr_it->word()->word->bounding_box(); TBOX box = pr_it->word()->word->bounding_box();
(*monitor->progress_callback)(monitor->progress, (*monitor->progress_callback)(monitor->progress, box.left(),
box.left(), box.right(), box.right(), box.top(), box.bottom());
box.top(), box.bottom());
} }
} else { } else {
monitor->progress = 70 + 30 * w / words->size(); monitor->progress = 70 + 30 * w / words->size();
if (monitor->progress_callback!=NULL) { if (monitor->progress_callback != NULL) {
(*monitor->progress_callback)(monitor->progress, (*monitor->progress_callback)(monitor->progress, 0, 0, 0, 0);
0, 0, 0, 0);
} }
} }
if (monitor->deadline_exceeded() || if (monitor->deadline_exceeded() ||
...@@ -252,7 +255,8 @@ bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC* monitor, ...@@ -252,7 +255,8 @@ bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC* monitor,
pr_it->forward(); pr_it->forward();
ASSERT_HOST(pr_it->word() != NULL); ASSERT_HOST(pr_it->word() != NULL);
bool make_next_word_fuzzy = false; bool make_next_word_fuzzy = false;
if (ReassignDiacritics(pass_n, pr_it, &make_next_word_fuzzy)) { if (!AnyLSTMLang() &&
ReassignDiacritics(pass_n, pr_it, &make_next_word_fuzzy)) {
// Needs to be setup again to see the new outlines in the chopped_word. // Needs to be setup again to see the new outlines in the chopped_word.
SetupWordPassN(pass_n, word); SetupWordPassN(pass_n, word);
} }
...@@ -297,6 +301,16 @@ bool Tesseract::recog_all_words(PAGE_RES* page_res, ...@@ -297,6 +301,16 @@ bool Tesseract::recog_all_words(PAGE_RES* page_res,
const TBOX* target_word_box, const TBOX* target_word_box,
const char* word_config, const char* word_config,
int dopasses) { int dopasses) {
// PSM_RAW_LINE is a special-case mode in which the layout analysis is
// completely ignored and LSTM is run on the raw image. There is no hope
// of running normal tesseract in this situation or of integrating output.
#ifndef ANDROID_BUILD
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY &&
tessedit_pageseg_mode == PSM_RAW_LINE) {
RecogRawLine(page_res);
return true;
}
#endif
PAGE_RES_IT page_res_it(page_res); PAGE_RES_IT page_res_it(page_res);
if (tessedit_minimal_rej_pass1) { if (tessedit_minimal_rej_pass1) {
...@@ -385,7 +399,7 @@ bool Tesseract::recog_all_words(PAGE_RES* page_res, ...@@ -385,7 +399,7 @@ bool Tesseract::recog_all_words(PAGE_RES* page_res,
// The next passes can only be run if tesseract has been used, as cube // The next passes can only be run if tesseract has been used, as cube
// doesn't set all the necessary outputs in WERD_RES. // doesn't set all the necessary outputs in WERD_RES.
if (AnyTessLang()) { if (AnyTessLang() && !AnyLSTMLang()) {
// ****************** Pass 3 ******************* // ****************** Pass 3 *******************
// Fix fuzzy spaces. // Fix fuzzy spaces.
set_global_loc_code(LOC_FUZZY_SPACE); set_global_loc_code(LOC_FUZZY_SPACE);
...@@ -1362,6 +1376,19 @@ void Tesseract::classify_word_pass1(const WordData& word_data, ...@@ -1362,6 +1376,19 @@ void Tesseract::classify_word_pass1(const WordData& word_data,
cube_word_pass1(block, row, *in_word); cube_word_pass1(block, row, *in_word);
return; return;
} }
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
if (!(*in_word)->odd_size) {
LSTMRecognizeWord(*block, row, *in_word, out_words);
if (!out_words->empty())
return; // Successful lstm recognition.
}
// Fall back to tesseract for failed words or odd words.
(*in_word)->SetupForRecognition(unicharset, this, BestPix(),
OEM_TESSERACT_ONLY, NULL,
classify_bln_numeric_mode,
textord_use_cjk_fp_model,
poly_allow_detailed_fx, row, block);
}
#endif #endif
WERD_RES* word = *in_word; WERD_RES* word = *in_word;
match_word_pass_n(1, word, row, block); match_word_pass_n(1, word, row, block);
...@@ -1496,10 +1523,6 @@ void Tesseract::classify_word_pass2(const WordData& word_data, ...@@ -1496,10 +1523,6 @@ void Tesseract::classify_word_pass2(const WordData& word_data,
WERD_RES** in_word, WERD_RES** in_word,
PointerVector<WERD_RES>* out_words) { PointerVector<WERD_RES>* out_words) {
// Return if we do not want to run Tesseract. // Return if we do not want to run Tesseract.
if (tessedit_ocr_engine_mode != OEM_TESSERACT_ONLY &&
tessedit_ocr_engine_mode != OEM_TESSERACT_CUBE_COMBINED &&
word_data.word->best_choice != NULL)
return;
if (tessedit_ocr_engine_mode == OEM_CUBE_ONLY) { if (tessedit_ocr_engine_mode == OEM_CUBE_ONLY) {
return; return;
} }
......
///////////////////////////////////////////////////////////////////////
// File: linerec.cpp
// Description: Top-level line-based recognition module for Tesseract.
// Author: Ray Smith
// Created: Thu May 02 09:47:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "tesseractclass.h"
#include "allheaders.h"
#include "boxread.h"
#include "imagedata.h"
#ifndef ANDROID_BUILD
#include "lstmrecognizer.h"
#include "recodebeam.h"
#endif
#include "ndminx.h"
#include "pageres.h"
#include "tprintf.h"
namespace tesseract {
// Arbitrary penalty for non-dictionary words.
// TODO(rays) How to learn this?
const float kNonDictionaryPenalty = 5.0f;
// Scale factor to make certainty more comparable to Tesseract.
const float kCertaintyScale = 7.0f;
// Worst acceptable certainty for a dictionary word.
const float kWorstDictCertainty = -25.0f;
// Generates training data for training a line recognizer, eg LSTM.
// Breaks the page into lines, according to the boxes, and writes them to a
// serialized DocumentData based on output_basename.
void Tesseract::TrainLineRecognizer(const STRING& input_imagename,
const STRING& output_basename,
BLOCK_LIST *block_list) {
STRING lstmf_name = output_basename + ".lstmf";
DocumentData images(lstmf_name);
if (applybox_page > 0) {
// Load existing document for the previous pages.
if (!images.LoadDocument(lstmf_name.string(), "eng", 0, 0, NULL)) {
tprintf("Failed to read training data from %s!\n", lstmf_name.string());
return;
}
}
GenericVector<TBOX> boxes;
GenericVector<STRING> texts;
// Get the boxes for this page, if there are any.
if (!ReadAllBoxes(applybox_page, false, input_imagename, &boxes, &texts, NULL,
NULL) ||
boxes.empty()) {
tprintf("Failed to read boxes from %s\n", input_imagename.string());
return;
}
TrainFromBoxes(boxes, texts, block_list, &images);
if (!images.SaveDocument(lstmf_name.string(), NULL)) {
tprintf("Failed to write training data to %s!\n", lstmf_name.string());
}
}
// Generates training data for training a line recognizer, eg LSTM.
// Breaks the boxes into lines, normalizes them, converts to ImageData and
// appends them to the given training_data.
void Tesseract::TrainFromBoxes(const GenericVector<TBOX>& boxes,
                               const GenericVector<STRING>& texts,
                               BLOCK_LIST *block_list,
                               DocumentData* training_data) {
  int box_count = boxes.size();
  // Process all the text lines in this page, as defined by the boxes.
  int end_box = 0;
  for (int start_box = 0; start_box < box_count; start_box = end_box) {
    // Find the textline of boxes starting at start and their bounding box.
    TBOX line_box = boxes[start_box];
    STRING line_str = texts[start_box];
    // A "\t" text entry acts as the line separator: accumulate boxes and
    // text until one is found (or the boxes run out).
    for (end_box = start_box + 1; end_box < box_count && texts[end_box] != "\t";
         ++end_box) {
      line_box += boxes[end_box];
      line_str += texts[end_box];
    }
    // Find the most overlapping block.
    BLOCK* best_block = NULL;
    int best_overlap = 0;
    BLOCK_IT b_it(block_list);
    for (b_it.mark_cycle_pt(); !b_it.cycled_list(); b_it.forward()) {
      BLOCK* block = b_it.data();
      if (block->poly_block() != NULL && !block->poly_block()->IsText())
        continue; // Not a text block.
      // Rotate the block's box into image coordinates so it is comparable
      // with line_box, which comes from the (image-space) box file.
      TBOX block_box = block->bounding_box();
      block_box.rotate(block->re_rotation());
      if (block_box.major_overlap(line_box)) {
        TBOX overlap_box = line_box.intersection(block_box);
        if (overlap_box.area() > best_overlap) {
          best_overlap = overlap_box.area();
          best_block = block;
        }
      }
    }
    // Lines that overlap no text block are reported and dropped; others are
    // cut out of the image and appended to the training document.
    ImageData* imagedata = NULL;
    if (best_block == NULL) {
      tprintf("No block overlapping textline: %s\n", line_str.string());
    } else {
      imagedata = GetLineData(line_box, boxes, texts, start_box, end_box,
                              *best_block);
    }
    if (imagedata != NULL)
      training_data->AddPageToDocument(imagedata);
    // Skip the "\t" separator so the next iteration starts on a real box.
    if (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
  }
}
// Returns an Imagedata containing the image of the given box,
// and ground truth boxes/truth text if available in the input.
// The image is not normalized in any way.
ImageData* Tesseract::GetLineData(const TBOX& line_box,
                                  const GenericVector<TBOX>& boxes,
                                  const GenericVector<STRING>& texts,
                                  int start_box, int end_box,
                                  const BLOCK& block) {
  TBOX clipped_box;
  ImageData* image_data =
      GetRectImage(line_box, block, kImagePadding, &clipped_box);
  if (image_data == NULL) return NULL;
  image_data->set_page_number(applybox_page);
  // Rotation that undoes the block's re_rotation, taking the truth boxes to
  // the same coordinates as the clipped image, then shift them so they are
  // relative to the image origin.
  FCOORD back_rotation(block.re_rotation().x(), -block.re_rotation().y());
  ICOORD origin_shift = -clipped_box.botleft();
  GenericVector<TBOX> rel_boxes;
  GenericVector<STRING> rel_texts;
  for (int b = start_box; b < end_box; ++b) {
    TBOX rel_box = boxes[b];
    rel_box.rotate(back_rotation);
    rel_box.move(origin_shift);
    rel_boxes.push_back(rel_box);
    rel_texts.push_back(texts[b]);
  }
  // Every truth box on this line belongs to the same page.
  GenericVector<int> page_numbers;
  page_numbers.init_to_size(rel_boxes.size(), applybox_page);
  image_data->AddBoxes(rel_boxes, rel_texts, page_numbers);
  return image_data;
}
// Helper gets the image of a rectangle, using the block.re_rotation() if
// needed to get to the image, and rotating the result back to horizontal
// layout. (CJK characters will be on their left sides) The vertical text flag
// is set in the returned ImageData if the text was originally vertical, which
// can be used to invoke a different CJK recognition engine. The revised_box
// is also returned to enable calculation of output bounding boxes.
// Returns NULL if the padded box does not intersect the image or the clip
// operation fails.
ImageData* Tesseract::GetRectImage(const TBOX& box, const BLOCK& block,
                                   int padding, TBOX* revised_box) const {
  TBOX wbox = box;
  wbox.pad(padding, padding);
  *revised_box = wbox;
  // Number of clockwise 90 degree rotations needed to get back to tesseract
  // coords from the clipped image.
  int num_rotations = 0;
  if (block.re_rotation().y() > 0.0f)
    num_rotations = 1;
  else if (block.re_rotation().x() < 0.0f)
    num_rotations = 2;
  else if (block.re_rotation().y() < 0.0f)
    num_rotations = 3;
  // Handle two cases automatically: 1 the box came from the block, 2 the box
  // came from a box file, and refers to the image, which the block may not.
  if (block.bounding_box().major_overlap(*revised_box))
    revised_box->rotate(block.re_rotation());
  // Now revised_box always refers to the image.
  // BestPix is never colormapped, but may be of any depth.
  Pix* pix = BestPix();
  int width = pixGetWidth(pix);
  int height = pixGetHeight(pix);
  TBOX image_box(0, 0, width, height);
  // Clip to image bounds.
  *revised_box &= image_box;
  if (revised_box->null_box()) return NULL;
  Box* clip_box = boxCreate(revised_box->left(), height - revised_box->top(),
                            revised_box->width(), revised_box->height());
  Pix* box_pix = pixClipRectangle(pix, clip_box, NULL);
  // Destroy the clip box before checking the result, so it cannot leak when
  // pixClipRectangle fails and we return early.
  boxDestroy(&clip_box);
  if (box_pix == NULL) return NULL;
  if (num_rotations > 0) {
    // Rotate the clipped image back to horizontal layout.
    Pix* rot_pix = pixRotateOrth(box_pix, num_rotations);
    pixDestroy(&box_pix);
    box_pix = rot_pix;
  }
  // Convert sub-8-bit images to 8 bit.
  int depth = pixGetDepth(box_pix);
  if (depth < 8) {
    Pix* grey;
    grey = pixConvertTo8(box_pix, false);
    pixDestroy(&box_pix);
    box_pix = grey;
  }
  bool vertical_text = false;
  if (num_rotations > 0) {
    // Rotate the clipped revised box back to internal coordinates.
    FCOORD rotation(block.re_rotation().x(), -block.re_rotation().y());
    revised_box->rotate(rotation);
    if (num_rotations != 2)
      vertical_text = true;
  }
  // The returned ImageData takes ownership of box_pix.
  return new ImageData(vertical_text, box_pix);
}
#ifndef ANDROID_BUILD
// Top-level function recognizes a single raw line.
void Tesseract::RecogRawLine(PAGE_RES* page_res) {
  PAGE_RES_IT page_it(page_res);
  PointerVector<WERD_RES> recognized;
  // Treat the whole page as one word/line and run the LSTM on it directly.
  LSTMRecognizeWord(*page_it.block()->block, page_it.row()->row,
                    page_it.word(), &recognized);
  if (getDict().stopper_debug_level >= 1) {
    for (int i = 0; i < recognized.size(); ++i)
      recognized[i]->DebugWordChoices(true, NULL);
  }
  // Swap the recognized words into the page result in place of the original.
  page_it.ReplaceCurrentWord(&recognized);
}
// Recognizes a word or group of words, converting to WERD_RES in *words.
// Analogous to classify_word_pass1, but can handle a group of words as well.
void Tesseract::LSTMRecognizeWord(const BLOCK& block, ROW *row, WERD_RES *word,
PointerVector<WERD_RES>* words) {
TBOX word_box = word->word->bounding_box();
// Get the word image - no frills.
if (tessedit_pageseg_mode == PSM_SINGLE_WORD ||
tessedit_pageseg_mode == PSM_RAW_LINE) {
// In single word mode, use the whole image without any other row/word
// interpretation.
word_box = TBOX(0, 0, ImageWidth(), ImageHeight());
} else {
float baseline = row->base_line((word_box.left() + word_box.right()) / 2);
if (baseline + row->descenders() < word_box.bottom())
word_box.set_bottom(baseline + row->descenders());
if (baseline + row->x_height() + row->ascenders() > word_box.top())
word_box.set_top(baseline + row->x_height() + row->ascenders());
}
ImageData* im_data = GetRectImage(word_box, block, kImagePadding, &word_box);
if (im_data == NULL) return;
lstm_recognizer_->RecognizeLine(*im_data, true, classify_debug_level > 0,
kWorstDictCertainty / kCertaintyScale,
lstm_use_matrix, &unicharset, word_box, 2.0,
false, words);
delete im_data;
SearchWords(words);
}
// Apply segmentation search to the given set of words, within the constraints
// of the existing ratings matrix. If there is already a best_choice on a word
// leaves it untouched and just sets the done/accepted etc flags.
// Words that end up with no best_choice, or with impossibly bad certainty,
// are removed from *words.
void Tesseract::SearchWords(PointerVector<WERD_RES>* words) {
  // Run the segmentation search on the network outputs and make a BoxWord
  // for each of the output words.
  // If we drop a word as junk, then there is always a space in front of the
  // next.
  bool deleted_prev = false;
  for (int w = 0; w < words->size(); ++w) {
    WERD_RES* word = (*words)[w];
    if (word->best_choice == NULL) {
      // If we are using the beam search, the unicharset had better match!
      word->SetupWordScript(unicharset);
      WordSearch(word);
    } else if (word->best_choice->unicharset() == &unicharset &&
               !lstm_recognizer_->IsRecoding()) {
      // We set up the word without using the dictionary, so set the permuter
      // now, but we can only do it because the unicharsets match.
      word->best_choice->set_permuter(
          getDict().valid_word(*word->best_choice, true));
    }
    if (word->best_choice == NULL) {
      // It is a dud: WordSearch failed to produce a choice. Remove it and
      // re-examine index w, which now holds the next word.
      words->remove(w);
      --w;
      deleted_prev = true;
    } else {
      // Set the best state (the per-blob segmentation) from the choice.
      for (int i = 0; i < word->best_choice->length(); ++i) {
        int length = word->best_choice->state(i);
        word->best_state.push_back(length);
      }
      word->tess_failed = false;
      word->tess_accepted = true;
      word->tess_would_adapt = false;
      word->done = true;
      word->tesseract = this;
      // Scale certainty so it is more comparable to Tesseract's values.
      float word_certainty = MIN(word->space_certainty,
                                 word->best_choice->certainty());
      word_certainty *= kCertaintyScale;
      // Arbitrary ding factor for non-dictionary words.
      if (!lstm_recognizer_->IsRecoding() &&
          !Dict::valid_word_permuter(word->best_choice->permuter(), true))
        word_certainty -= kNonDictionaryPenalty;
      if (getDict().stopper_debug_level >= 1) {
        tprintf("Best choice certainty=%g, space=%g, scaled=%g, final=%g\n",
                word->best_choice->certainty(), word->space_certainty,
                MIN(word->space_certainty, word->best_choice->certainty()) *
                    kCertaintyScale,
                word_certainty);
        word->best_choice->print();
      }
      // Discard words that are impossibly bad, but allow a bit more for
      // dictionary words.
      if (word_certainty >= RecodeBeamSearch::kMinCertainty ||
          (word_certainty >= kWorstDictCertainty &&
           Dict::valid_word_permuter(word->best_choice->permuter(), true))) {
        word->best_choice->set_certainty(word_certainty);
        // NOTE(review): deleted_prev is never reset after a kept word, so a
        // single deletion forces blanks on all later kept words — confirm
        // this is intended.
        if (deleted_prev) word->word->set_blanks(1);
      } else {
        if (getDict().stopper_debug_level >= 1) {
          tprintf("Deleting word with certainty %g\n", word_certainty);
          word->best_choice->print();
        }
        // It is a dud.
        words->remove(w);
        --w;
        deleted_prev = true;
      }
    }
  }
}
#endif // ANDROID_BUILD
} // namespace tesseract.
...@@ -220,6 +220,12 @@ bool LTRResultIterator::WordIsFromDictionary() const { ...@@ -220,6 +220,12 @@ bool LTRResultIterator::WordIsFromDictionary() const {
permuter == USER_DAWG_PERM; permuter == USER_DAWG_PERM;
} }
// Returns the number of blanks before the current word.
int LTRResultIterator::BlanksBeforeWord() const {
  // When already past the last word, report a single blank.
  return it_->word() != NULL ? it_->word()->word->space() : 1;
}
// Returns true if the current word is numeric. // Returns true if the current word is numeric.
bool LTRResultIterator::WordIsNumeric() const { bool LTRResultIterator::WordIsNumeric() const {
if (it_->word() == NULL) return false; // Already at the end! if (it_->word() == NULL) return false; // Already at the end!
......
...@@ -124,6 +124,9 @@ class TESS_API LTRResultIterator : public PageIterator { ...@@ -124,6 +124,9 @@ class TESS_API LTRResultIterator : public PageIterator {
// Returns true if the current word was found in a dictionary. // Returns true if the current word was found in a dictionary.
bool WordIsFromDictionary() const; bool WordIsFromDictionary() const;
// Returns the number of blanks before the current word.
int BlanksBeforeWord() const;
// Returns true if the current word is numeric. // Returns true if the current word is numeric.
bool WordIsNumeric() const; bool WordIsNumeric() const;
......
...@@ -40,6 +40,9 @@ ...@@ -40,6 +40,9 @@
#include "efio.h" #include "efio.h"
#include "danerror.h" #include "danerror.h"
#include "globals.h" #include "globals.h"
#ifndef ANDROID_BUILD
#include "lstmrecognizer.h"
#endif
#include "tesseractclass.h" #include "tesseractclass.h"
#include "params.h" #include "params.h"
...@@ -214,6 +217,18 @@ bool Tesseract::init_tesseract_lang_data( ...@@ -214,6 +217,18 @@ bool Tesseract::init_tesseract_lang_data(
ASSERT_HOST(init_cube_objects(true, &tessdata_manager)); ASSERT_HOST(init_cube_objects(true, &tessdata_manager));
if (tessdata_manager_debug_level) if (tessdata_manager_debug_level)
tprintf("Loaded Cube with combiner\n"); tprintf("Loaded Cube with combiner\n");
} else if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
if (tessdata_manager.SeekToStart(TESSDATA_LSTM)) {
lstm_recognizer_ = new LSTMRecognizer;
TFile fp;
fp.Open(tessdata_manager.GetDataFilePtr(), -1);
ASSERT_HOST(lstm_recognizer_->DeSerialize(tessdata_manager.swap(), &fp));
if (lstm_use_matrix)
lstm_recognizer_->LoadDictionary(tessdata_path.string(), language);
} else {
tprintf("Error: LSTM requested, but not present!! Loading tesseract.\n");
tessedit_ocr_engine_mode.set_value(OEM_TESSERACT_ONLY);
}
} }
#endif #endif
// Init ParamsModel. // Init ParamsModel.
...@@ -409,8 +424,7 @@ int Tesseract::init_tesseract_internal( ...@@ -409,8 +424,7 @@ int Tesseract::init_tesseract_internal(
// If only Cube will be used, skip loading Tesseract classifier's // If only Cube will be used, skip loading Tesseract classifier's
// pre-trained templates. // pre-trained templates.
bool init_tesseract_classifier = bool init_tesseract_classifier =
(tessedit_ocr_engine_mode == OEM_TESSERACT_ONLY || tessedit_ocr_engine_mode != OEM_CUBE_ONLY;
tessedit_ocr_engine_mode == OEM_TESSERACT_CUBE_COMBINED);
// If only Cube will be used and if it has its own Unicharset, // If only Cube will be used and if it has its own Unicharset,
// skip initializing permuter and loading Tesseract Dawgs. // skip initializing permuter and loading Tesseract Dawgs.
bool init_dict = bool init_dict =
...@@ -468,7 +482,9 @@ int Tesseract::init_tesseract_lm(const char *arg0, ...@@ -468,7 +482,9 @@ int Tesseract::init_tesseract_lm(const char *arg0,
if (!init_tesseract_lang_data(arg0, textbase, language, OEM_TESSERACT_ONLY, if (!init_tesseract_lang_data(arg0, textbase, language, OEM_TESSERACT_ONLY,
NULL, 0, NULL, NULL, false)) NULL, 0, NULL, NULL, false))
return -1; return -1;
getDict().Load(Dict::GlobalDawgCache()); getDict().SetupForLoad(Dict::GlobalDawgCache());
getDict().Load(tessdata_manager.GetDataFileName().string(), lang);
getDict().FinishLoad();
tessdata_manager.End(); tessdata_manager.End();
return 0; return 0;
} }
......
...@@ -49,6 +49,7 @@ ...@@ -49,6 +49,7 @@
#include "equationdetect.h" #include "equationdetect.h"
#include "globals.h" #include "globals.h"
#ifndef NO_CUBE_BUILD #ifndef NO_CUBE_BUILD
#include "lstmrecognizer.h"
#include "tesseract_cube_combiner.h" #include "tesseract_cube_combiner.h"
#endif #endif
...@@ -65,6 +66,9 @@ Tesseract::Tesseract() ...@@ -65,6 +66,9 @@ Tesseract::Tesseract()
"Generate training data from boxed chars", this->params()), "Generate training data from boxed chars", this->params()),
BOOL_MEMBER(tessedit_make_boxes_from_boxes, false, BOOL_MEMBER(tessedit_make_boxes_from_boxes, false,
"Generate more boxes from boxed chars", this->params()), "Generate more boxes from boxed chars", this->params()),
BOOL_MEMBER(tessedit_train_line_recognizer, false,
"Break input into lines and remap boxes if present",
this->params()),
BOOL_MEMBER(tessedit_dump_pageseg_images, false, BOOL_MEMBER(tessedit_dump_pageseg_images, false,
"Dump intermediate images made during page segmentation", "Dump intermediate images made during page segmentation",
this->params()), this->params()),
...@@ -222,6 +226,8 @@ Tesseract::Tesseract() ...@@ -222,6 +226,8 @@ Tesseract::Tesseract()
"(more accurate)", "(more accurate)",
this->params()), this->params()),
INT_MEMBER(cube_debug_level, 0, "Print cube debug info.", this->params()), INT_MEMBER(cube_debug_level, 0, "Print cube debug info.", this->params()),
BOOL_MEMBER(lstm_use_matrix, 1,
"Use ratings matrix/beam search with lstm", this->params()),
STRING_MEMBER(outlines_odd, "%| ", "Non standard number of outlines", STRING_MEMBER(outlines_odd, "%| ", "Non standard number of outlines",
this->params()), this->params()),
STRING_MEMBER(outlines_2, "ij!?%\":;", "Non standard number of outlines", STRING_MEMBER(outlines_2, "ij!?%\":;", "Non standard number of outlines",
...@@ -605,6 +611,7 @@ Tesseract::Tesseract() ...@@ -605,6 +611,7 @@ Tesseract::Tesseract()
pix_binary_(NULL), pix_binary_(NULL),
cube_binary_(NULL), cube_binary_(NULL),
pix_grey_(NULL), pix_grey_(NULL),
pix_original_(NULL),
pix_thresholds_(NULL), pix_thresholds_(NULL),
source_resolution_(0), source_resolution_(0),
textord_(this), textord_(this),
...@@ -619,11 +626,16 @@ Tesseract::Tesseract() ...@@ -619,11 +626,16 @@ Tesseract::Tesseract()
cube_cntxt_(NULL), cube_cntxt_(NULL),
tess_cube_combiner_(NULL), tess_cube_combiner_(NULL),
#endif #endif
equ_detect_(NULL) { equ_detect_(NULL),
#ifndef ANDROID_BUILD
lstm_recognizer_(NULL),
#endif
train_line_page_num_(0) {
} }
Tesseract::~Tesseract() { Tesseract::~Tesseract() {
Clear(); Clear();
pixDestroy(&pix_original_);
end_tesseract(); end_tesseract();
sub_langs_.delete_data_pointers(); sub_langs_.delete_data_pointers();
#ifndef NO_CUBE_BUILD #ifndef NO_CUBE_BUILD
...@@ -636,6 +648,8 @@ Tesseract::~Tesseract() { ...@@ -636,6 +648,8 @@ Tesseract::~Tesseract() {
delete tess_cube_combiner_; delete tess_cube_combiner_;
tess_cube_combiner_ = NULL; tess_cube_combiner_ = NULL;
} }
delete lstm_recognizer_;
lstm_recognizer_ = NULL;
#endif #endif
} }
......
...@@ -102,7 +102,10 @@ class CubeLineObject; ...@@ -102,7 +102,10 @@ class CubeLineObject;
class CubeObject; class CubeObject;
class CubeRecoContext; class CubeRecoContext;
#endif #endif
class DocumentData;
class EquationDetect; class EquationDetect;
class ImageData;
class LSTMRecognizer;
class Tesseract; class Tesseract;
#ifndef NO_CUBE_BUILD #ifndef NO_CUBE_BUILD
class TesseractCubeCombiner; class TesseractCubeCombiner;
...@@ -189,7 +192,7 @@ class Tesseract : public Wordrec { ...@@ -189,7 +192,7 @@ class Tesseract : public Wordrec {
} }
// Destroy any existing pix and return a pointer to the pointer. // Destroy any existing pix and return a pointer to the pointer.
Pix** mutable_pix_binary() { Pix** mutable_pix_binary() {
Clear(); pixDestroy(&pix_binary_);
return &pix_binary_; return &pix_binary_;
} }
Pix* pix_binary() const { Pix* pix_binary() const {
...@@ -202,16 +205,20 @@ class Tesseract : public Wordrec { ...@@ -202,16 +205,20 @@ class Tesseract : public Wordrec {
pixDestroy(&pix_grey_); pixDestroy(&pix_grey_);
pix_grey_ = grey_pix; pix_grey_ = grey_pix;
} }
// Returns a pointer to a Pix representing the best available image of the Pix* pix_original() const { return pix_original_; }
// page. The image will be 8-bit grey if the input was grey or color. Note // Takes ownership of the given original_pix.
// that in grey 0 is black and 255 is white. If the input was binary, then void set_pix_original(Pix* original_pix) {
// the returned Pix will be binary. Note that here black is 1 and white is 0. pixDestroy(&pix_original_);
// To tell the difference pixGetDepth() will return 8 or 1. pix_original_ = original_pix;
// In either case, the return value is a borrowed Pix, and should not be
// deleted or pixDestroyed.
Pix* BestPix() const {
return pix_grey_ != NULL ? pix_grey_ : pix_binary_;
} }
// Returns a pointer to a Pix representing the best available (original) image
// of the page. Can be of any bit depth, but never color-mapped, as that has
// always been dealt with. Note that in grey and color, 0 is black and 255 is
// white. If the input was binary, then black is 1 and white is 0.
// To tell the difference pixGetDepth() will return 32, 8 or 1.
// In any case, the return value is a borrowed Pix, and should not be
// deleted or pixDestroyed.
Pix* BestPix() const { return pix_original_; }
void set_pix_thresholds(Pix* thresholds) { void set_pix_thresholds(Pix* thresholds) {
pixDestroy(&pix_thresholds_); pixDestroy(&pix_thresholds_);
pix_thresholds_ = thresholds; pix_thresholds_ = thresholds;
...@@ -263,6 +270,15 @@ class Tesseract : public Wordrec { ...@@ -263,6 +270,15 @@ class Tesseract : public Wordrec {
} }
return false; return false;
} }
// Returns true if any language uses the LSTM.
bool AnyLSTMLang() const {
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) return true;
for (int i = 0; i < sub_langs_.size(); ++i) {
if (sub_langs_[i]->tessedit_ocr_engine_mode == OEM_LSTM_ONLY)
return true;
}
return false;
}
void SetBlackAndWhitelist(); void SetBlackAndWhitelist();
...@@ -293,6 +309,48 @@ class Tesseract : public Wordrec { ...@@ -293,6 +309,48 @@ class Tesseract : public Wordrec {
// par_control.cpp // par_control.cpp
void PrerecAllWordsPar(const GenericVector<WordData>& words); void PrerecAllWordsPar(const GenericVector<WordData>& words);
//// linerec.cpp
// Generates training data for training a line recognizer, eg LSTM.
// Breaks the page into lines, according to the boxes, and writes them to a
// serialized DocumentData based on output_basename.
void TrainLineRecognizer(const STRING& input_imagename,
const STRING& output_basename,
BLOCK_LIST *block_list);
// Generates training data for training a line recognizer, eg LSTM.
// Breaks the boxes into lines, normalizes them, converts to ImageData and
// appends them to the given training_data.
void TrainFromBoxes(const GenericVector<TBOX>& boxes,
const GenericVector<STRING>& texts,
BLOCK_LIST *block_list,
DocumentData* training_data);
// Returns an Imagedata containing the image of the given textline,
// and ground truth boxes/truth text if available in the input.
// The image is not normalized in any way.
ImageData* GetLineData(const TBOX& line_box,
const GenericVector<TBOX>& boxes,
const GenericVector<STRING>& texts,
int start_box, int end_box,
const BLOCK& block);
// Helper gets the image of a rectangle, using the block.re_rotation() if
// needed to get to the image, and rotating the result back to horizontal
// layout. (CJK characters will be on their left sides) The vertical text flag
// is set in the returned ImageData if the text was originally vertical, which
// can be used to invoke a different CJK recognition engine. The revised_box
// is also returned to enable calculation of output bounding boxes.
ImageData* GetRectImage(const TBOX& box, const BLOCK& block, int padding,
TBOX* revised_box) const;
// Top-level function recognizes a single raw line.
void RecogRawLine(PAGE_RES* page_res);
// Recognizes a word or group of words, converting to WERD_RES in *words.
// Analogous to classify_word_pass1, but can handle a group of words as well.
void LSTMRecognizeWord(const BLOCK& block, ROW *row, WERD_RES *word,
PointerVector<WERD_RES>* words);
// Apply segmentation search to the given set of words, within the constraints
// of the existing ratings matrix. If there is already a best_choice on a word
// leaves it untouched and just sets the done/accepted etc flags.
void SearchWords(PointerVector<WERD_RES>* words);
//// control.h ///////////////////////////////////////////////////////// //// control.h /////////////////////////////////////////////////////////
bool ProcessTargetWord(const TBOX& word_box, const TBOX& target_word_box, bool ProcessTargetWord(const TBOX& word_box, const TBOX& target_word_box,
const char* word_config, int pass); const char* word_config, int pass);
...@@ -783,6 +841,8 @@ class Tesseract : public Wordrec { ...@@ -783,6 +841,8 @@ class Tesseract : public Wordrec {
"Generate training data from boxed chars"); "Generate training data from boxed chars");
BOOL_VAR_H(tessedit_make_boxes_from_boxes, false, BOOL_VAR_H(tessedit_make_boxes_from_boxes, false,
"Generate more boxes from boxed chars"); "Generate more boxes from boxed chars");
BOOL_VAR_H(tessedit_train_line_recognizer, false,
"Break input into lines and remap boxes if present");
BOOL_VAR_H(tessedit_dump_pageseg_images, false, BOOL_VAR_H(tessedit_dump_pageseg_images, false,
"Dump intermediate images made during page segmentation"); "Dump intermediate images made during page segmentation");
INT_VAR_H(tessedit_pageseg_mode, PSM_SINGLE_BLOCK, INT_VAR_H(tessedit_pageseg_mode, PSM_SINGLE_BLOCK,
...@@ -891,6 +951,7 @@ class Tesseract : public Wordrec { ...@@ -891,6 +951,7 @@ class Tesseract : public Wordrec {
"Run paragraph detection on the post-text-recognition " "Run paragraph detection on the post-text-recognition "
"(more accurate)"); "(more accurate)");
INT_VAR_H(cube_debug_level, 1, "Print cube debug info."); INT_VAR_H(cube_debug_level, 1, "Print cube debug info.");
BOOL_VAR_H(lstm_use_matrix, 1, "Use ratings matrix/beam searct with lstm");
STRING_VAR_H(outlines_odd, "%| ", "Non standard number of outlines"); STRING_VAR_H(outlines_odd, "%| ", "Non standard number of outlines");
STRING_VAR_H(outlines_2, "ij!?%\":;", "Non standard number of outlines"); STRING_VAR_H(outlines_2, "ij!?%\":;", "Non standard number of outlines");
BOOL_VAR_H(docqual_excuse_outline_errs, false, BOOL_VAR_H(docqual_excuse_outline_errs, false,
...@@ -1174,6 +1235,8 @@ class Tesseract : public Wordrec { ...@@ -1174,6 +1235,8 @@ class Tesseract : public Wordrec {
Pix* cube_binary_; Pix* cube_binary_;
// Grey-level input image if the input was not binary, otherwise NULL. // Grey-level input image if the input was not binary, otherwise NULL.
Pix* pix_grey_; Pix* pix_grey_;
// Original input image. Color if the input was color.
Pix* pix_original_;
// Thresholds that were used to generate the thresholded image from grey. // Thresholds that were used to generate the thresholded image from grey.
Pix* pix_thresholds_; Pix* pix_thresholds_;
// Input image resolution after any scaling. The resolution is not well // Input image resolution after any scaling. The resolution is not well
...@@ -1205,6 +1268,10 @@ class Tesseract : public Wordrec { ...@@ -1205,6 +1268,10 @@ class Tesseract : public Wordrec {
#endif #endif
// Equation detector. Note: this pointer is NOT owned by the class. // Equation detector. Note: this pointer is NOT owned by the class.
EquationDetect* equ_detect_; EquationDetect* equ_detect_;
// LSTM recognizer, if available.
LSTMRecognizer* lstm_recognizer_;
// Output "page" number (actually line number) using TrainLineRecognizer.
int train_line_page_num_;
}; };
} // namespace tesseract } // namespace tesseract
......
...@@ -152,19 +152,27 @@ void ImageThresholder::SetImage(const Pix* pix) { ...@@ -152,19 +152,27 @@ void ImageThresholder::SetImage(const Pix* pix) {
int depth; int depth;
pixGetDimensions(src, &image_width_, &image_height_, &depth); pixGetDimensions(src, &image_width_, &image_height_, &depth);
// Convert the image as necessary so it is one of binary, plain RGB, or // Convert the image as necessary so it is one of binary, plain RGB, or
// 8 bit with no colormap. // 8 bit with no colormap. Guarantee that we always end up with our own copy,
if (depth > 1 && depth < 8) { // not just a clone of the input.
if (pixGetColormap(src)) {
Pix* tmp = pixRemoveColormap(src, REMOVE_CMAP_BASED_ON_SRC);
depth = pixGetDepth(tmp);
if (depth > 1 && depth < 8) {
pix_ = pixConvertTo8(tmp, false);
pixDestroy(&tmp);
} else {
pix_ = tmp;
}
} else if (depth > 1 && depth < 8) {
pix_ = pixConvertTo8(src, false); pix_ = pixConvertTo8(src, false);
} else if (pixGetColormap(src)) {
pix_ = pixRemoveColormap(src, REMOVE_CMAP_BASED_ON_SRC);
} else { } else {
pix_ = pixClone(src); pix_ = pixCopy(NULL, src);
} }
depth = pixGetDepth(pix_); depth = pixGetDepth(pix_);
pix_channels_ = depth / 8; pix_channels_ = depth / 8;
pix_wpl_ = pixGetWpl(pix_); pix_wpl_ = pixGetWpl(pix_);
scale_ = 1; scale_ = 1;
estimated_res_ = yres_ = pixGetYRes(src); estimated_res_ = yres_ = pixGetYRes(pix_);
Init(); Init();
} }
......
...@@ -24,12 +24,18 @@ ...@@ -24,12 +24,18 @@
#include "imagedata.h" #include "imagedata.h"
#include <unistd.h>
#include "allheaders.h" #include "allheaders.h"
#include "boxread.h" #include "boxread.h"
#include "callcpp.h" #include "callcpp.h"
#include "helpers.h" #include "helpers.h"
#include "tprintf.h" #include "tprintf.h"
// Number of documents to read ahead while training. Doesn't need to be very
// large.
const int kMaxReadAhead = 8;
namespace tesseract { namespace tesseract {
WordFeature::WordFeature() : x_(0), y_(0), dir_(0) { WordFeature::WordFeature() : x_(0), y_(0), dir_(0) {
...@@ -182,6 +188,19 @@ bool ImageData::DeSerialize(bool swap, TFile* fp) { ...@@ -182,6 +188,19 @@ bool ImageData::DeSerialize(bool swap, TFile* fp) {
return true; return true;
} }
// As DeSerialize, but only seeks past the data - hence a static method.
bool ImageData::SkipDeSerialize(bool swap, TFile* fp) {
if (!STRING::SkipDeSerialize(swap, fp)) return false;
inT32 page_number;
if (fp->FRead(&page_number, sizeof(page_number), 1) != 1) return false;
if (!GenericVector<char>::SkipDeSerialize(swap, fp)) return false;
if (!STRING::SkipDeSerialize(swap, fp)) return false;
if (!GenericVector<TBOX>::SkipDeSerialize(swap, fp)) return false;
if (!GenericVector<STRING>::SkipDeSerializeClasses(swap, fp)) return false;
inT8 vertical = 0;
return fp->FRead(&vertical, sizeof(vertical), 1) == 1;
}
// Saves the given Pix as a PNG-encoded string and destroys it. // Saves the given Pix as a PNG-encoded string and destroys it.
void ImageData::SetPix(Pix* pix) { void ImageData::SetPix(Pix* pix) {
SetPixInternal(pix, &image_data_); SetPixInternal(pix, &image_data_);
...@@ -195,11 +214,12 @@ Pix* ImageData::GetPix() const { ...@@ -195,11 +214,12 @@ Pix* ImageData::GetPix() const {
// Gets anything and everything with a non-NULL pointer, prescaled to a // Gets anything and everything with a non-NULL pointer, prescaled to a
// given target_height (if 0, then the original image height), and aligned. // given target_height (if 0, then the original image height), and aligned.
// Also returns (if not NULL) the width and height of the scaled image. // Also returns (if not NULL) the width and height of the scaled image.
// The return value is the scale factor that was applied to the image to // The return value is the scaled Pix, which must be pixDestroyed after use,
// achieve the target_height. // and scale_factor (if not NULL) is set to the scale factor that was applied
float ImageData::PreScale(int target_height, Pix** pix, // to the image to achieve the target_height.
int* scaled_width, int* scaled_height, Pix* ImageData::PreScale(int target_height, float* scale_factor,
GenericVector<TBOX>* boxes) const { int* scaled_width, int* scaled_height,
GenericVector<TBOX>* boxes) const {
int input_width = 0; int input_width = 0;
int input_height = 0; int input_height = 0;
Pix* src_pix = GetPix(); Pix* src_pix = GetPix();
...@@ -213,19 +233,14 @@ float ImageData::PreScale(int target_height, Pix** pix, ...@@ -213,19 +233,14 @@ float ImageData::PreScale(int target_height, Pix** pix,
*scaled_width = IntCastRounded(im_factor * input_width); *scaled_width = IntCastRounded(im_factor * input_width);
if (scaled_height != NULL) if (scaled_height != NULL)
*scaled_height = target_height; *scaled_height = target_height;
if (pix != NULL) { // Get the scaled image.
// Get the scaled image. Pix* pix = pixScale(src_pix, im_factor, im_factor);
pixDestroy(pix); if (pix == NULL) {
*pix = pixScale(src_pix, im_factor, im_factor); tprintf("Scaling pix of size %d, %d by factor %g made null pix!!\n",
if (*pix == NULL) { input_width, input_height, im_factor);
tprintf("Scaling pix of size %d, %d by factor %g made null pix!!\n",
input_width, input_height, im_factor);
}
if (scaled_width != NULL)
*scaled_width = pixGetWidth(*pix);
if (scaled_height != NULL)
*scaled_height = pixGetHeight(*pix);
} }
if (scaled_width != NULL) *scaled_width = pixGetWidth(pix);
if (scaled_height != NULL) *scaled_height = pixGetHeight(pix);
pixDestroy(&src_pix); pixDestroy(&src_pix);
if (boxes != NULL) { if (boxes != NULL) {
// Get the boxes. // Get the boxes.
...@@ -241,7 +256,8 @@ float ImageData::PreScale(int target_height, Pix** pix, ...@@ -241,7 +256,8 @@ float ImageData::PreScale(int target_height, Pix** pix,
boxes->push_back(box); boxes->push_back(box);
} }
} }
return im_factor; if (scale_factor != NULL) *scale_factor = im_factor;
return pix;
} }
int ImageData::MemoryUsed() const { int ImageData::MemoryUsed() const {
...@@ -266,19 +282,20 @@ void ImageData::Display() const { ...@@ -266,19 +282,20 @@ void ImageData::Display() const {
// Draw the boxes. // Draw the boxes.
win->Pen(ScrollView::RED); win->Pen(ScrollView::RED);
win->Brush(ScrollView::NONE); win->Brush(ScrollView::NONE);
win->TextAttributes("Arial", kTextSize, false, false, false); int text_size = kTextSize;
for (int b = 0; b < boxes_.size(); ++b) { if (!boxes_.empty() && boxes_[0].height() * 2 < text_size)
boxes_[b].plot(win); text_size = boxes_[0].height() * 2;
win->Text(boxes_[b].left(), height + kTextSize, box_texts_[b].string()); win->TextAttributes("Arial", text_size, false, false, false);
TBOX scaled(boxes_[b]); if (!boxes_.empty()) {
scaled.scale(256.0 / height); for (int b = 0; b < boxes_.size(); ++b) {
scaled.plot(win); boxes_[b].plot(win);
win->Text(boxes_[b].left(), height + kTextSize, box_texts_[b].string());
}
} else {
// The full transcription.
win->Pen(ScrollView::CYAN);
win->Text(0, height + kTextSize * 2, transcription_.string());
} }
// The full transcription.
win->Pen(ScrollView::CYAN);
win->Text(0, height + kTextSize * 2, transcription_.string());
// Add the features.
win->Pen(ScrollView::GREEN);
win->Update(); win->Update();
window_wait(win); window_wait(win);
#endif #endif
...@@ -340,27 +357,51 @@ bool ImageData::AddBoxes(const char* box_text) { ...@@ -340,27 +357,51 @@ bool ImageData::AddBoxes(const char* box_text) {
return false; return false;
} }
// Thread function to call ReCachePages.
void* ReCachePagesFunc(void* data) {
DocumentData* document_data = reinterpret_cast<DocumentData*>(data);
document_data->ReCachePages();
return NULL;
}
DocumentData::DocumentData(const STRING& name) DocumentData::DocumentData(const STRING& name)
: document_name_(name), pages_offset_(0), total_pages_(0), : document_name_(name),
memory_used_(0), max_memory_(0), reader_(NULL) {} pages_offset_(-1),
total_pages_(-1),
memory_used_(0),
max_memory_(0),
reader_(NULL) {}
DocumentData::~DocumentData() {} DocumentData::~DocumentData() {
SVAutoLock lock_p(&pages_mutex_);
SVAutoLock lock_g(&general_mutex_);
}
// Reads all the pages in the given lstmf filename to the cache. The reader // Reads all the pages in the given lstmf filename to the cache. The reader
// is used to read the file. // is used to read the file.
bool DocumentData::LoadDocument(const char* filename, const char* lang, bool DocumentData::LoadDocument(const char* filename, const char* lang,
int start_page, inT64 max_memory, int start_page, inT64 max_memory,
FileReader reader) { FileReader reader) {
SetDocument(filename, lang, max_memory, reader);
pages_offset_ = start_page;
return ReCachePages();
}
// Sets up the document, without actually loading it.
void DocumentData::SetDocument(const char* filename, const char* lang,
inT64 max_memory, FileReader reader) {
SVAutoLock lock_p(&pages_mutex_);
SVAutoLock lock(&general_mutex_);
document_name_ = filename; document_name_ = filename;
lang_ = lang; lang_ = lang;
pages_offset_ = start_page; pages_offset_ = -1;
max_memory_ = max_memory; max_memory_ = max_memory;
reader_ = reader; reader_ = reader;
return ReCachePages();
} }
// Writes all the pages to the given filename. Returns false on error. // Writes all the pages to the given filename. Returns false on error.
bool DocumentData::SaveDocument(const char* filename, FileWriter writer) { bool DocumentData::SaveDocument(const char* filename, FileWriter writer) {
SVAutoLock lock(&pages_mutex_);
TFile fp; TFile fp;
fp.OpenWrite(NULL); fp.OpenWrite(NULL);
if (!pages_.Serialize(&fp) || !fp.CloseWrite(filename, writer)) { if (!pages_.Serialize(&fp) || !fp.CloseWrite(filename, writer)) {
...@@ -370,112 +411,166 @@ bool DocumentData::SaveDocument(const char* filename, FileWriter writer) { ...@@ -370,112 +411,166 @@ bool DocumentData::SaveDocument(const char* filename, FileWriter writer) {
return true; return true;
} }
bool DocumentData::SaveToBuffer(GenericVector<char>* buffer) { bool DocumentData::SaveToBuffer(GenericVector<char>* buffer) {
SVAutoLock lock(&pages_mutex_);
TFile fp; TFile fp;
fp.OpenWrite(buffer); fp.OpenWrite(buffer);
return pages_.Serialize(&fp); return pages_.Serialize(&fp);
} }
// Adds the given page data to this document, counting up memory.
void DocumentData::AddPageToDocument(ImageData* page) {
SVAutoLock lock(&pages_mutex_);
pages_.push_back(page);
set_memory_used(memory_used() + page->MemoryUsed());
}
// If the given index is not currently loaded, loads it using a separate
// thread.
void DocumentData::LoadPageInBackground(int index) {
ImageData* page = NULL;
if (IsPageAvailable(index, &page)) return;
SVAutoLock lock(&pages_mutex_);
if (pages_offset_ == index) return;
pages_offset_ = index;
pages_.clear();
SVSync::StartThread(ReCachePagesFunc, this);
}
// Returns a pointer to the page with the given index, modulo the total // Returns a pointer to the page with the given index, modulo the total
// number of pages, recaching if needed. // number of pages. Blocks until the background load is completed.
const ImageData* DocumentData::GetPage(int index) { const ImageData* DocumentData::GetPage(int index) {
index = Modulo(index, total_pages_); ImageData* page = NULL;
if (index < pages_offset_ || index >= pages_offset_ + pages_.size()) { while (!IsPageAvailable(index, &page)) {
pages_offset_ = index; // If there is no background load scheduled, schedule one now.
if (!ReCachePages()) return NULL; pages_mutex_.Lock();
bool needs_loading = pages_offset_ != index;
pages_mutex_.Unlock();
if (needs_loading) LoadPageInBackground(index);
// We can't directly load the page, or the background load will delete it
// while the caller is using it, so give it a chance to work.
sleep(1);
} }
return pages_[index - pages_offset_]; return page;
}
// Returns true if the requested page is available, and provides a pointer,
// which may be NULL if the document is empty. May block, even though it
// doesn't guarantee to return true.
bool DocumentData::IsPageAvailable(int index, ImageData** page) {
SVAutoLock lock(&pages_mutex_);
int num_pages = NumPages();
if (num_pages == 0 || index < 0) {
*page = NULL; // Empty Document.
return true;
}
if (num_pages > 0) {
index = Modulo(index, num_pages);
if (pages_offset_ <= index && index < pages_offset_ + pages_.size()) {
*page = pages_[index - pages_offset_]; // Page is available already.
return true;
}
}
return false;
} }
// Loads as many pages can fit in max_memory_ starting at index pages_offset_. // Removes all pages from memory and frees the memory, but does not forget
// the document metadata.
inT64 DocumentData::UnCache() {
SVAutoLock lock(&pages_mutex_);
inT64 memory_saved = memory_used();
pages_.clear();
pages_offset_ = -1;
set_total_pages(-1);
set_memory_used(0);
tprintf("Unloaded document %s, saving %d memory\n", document_name_.string(),
memory_saved);
return memory_saved;
}
// Locks the pages_mutex_ and Loads as many pages can fit in max_memory_
// starting at index pages_offset_.
bool DocumentData::ReCachePages() { bool DocumentData::ReCachePages() {
SVAutoLock lock(&pages_mutex_);
// Read the file. // Read the file.
set_total_pages(0);
set_memory_used(0);
int loaded_pages = 0;
pages_.truncate(0);
TFile fp; TFile fp;
if (!fp.Open(document_name_, reader_)) return false; if (!fp.Open(document_name_, reader_) ||
memory_used_ = 0; !PointerVector<ImageData>::DeSerializeSize(false, &fp, &loaded_pages) ||
if (!pages_.DeSerialize(false, &fp)) { loaded_pages <= 0) {
tprintf("Deserialize failed: %s\n", document_name_.string()); tprintf("Deserialize header failed: %s\n", document_name_.string());
pages_.truncate(0);
return false; return false;
} }
total_pages_ = pages_.size(); pages_offset_ %= loaded_pages;
pages_offset_ %= total_pages_; // Skip pages before the first one we want, and load the rest until max
// Delete pages before the first one we want, and relocate the rest. // memory and skip the rest after that.
int page; int page;
for (page = 0; page < pages_.size(); ++page) { for (page = 0; page < loaded_pages; ++page) {
if (page < pages_offset_) { if (page < pages_offset_ ||
delete pages_[page]; (max_memory_ > 0 && memory_used() > max_memory_)) {
pages_[page] = NULL; if (!PointerVector<ImageData>::DeSerializeSkip(false, &fp)) break;
} else { } else {
ImageData* image_data = pages_[page]; if (!pages_.DeSerializeElement(false, &fp)) break;
if (max_memory_ > 0 && page > pages_offset_ && ImageData* image_data = pages_.back();
memory_used_ + image_data->MemoryUsed() > max_memory_)
break; // Don't go over memory quota unless the first image.
if (image_data->imagefilename().length() == 0) { if (image_data->imagefilename().length() == 0) {
image_data->set_imagefilename(document_name_); image_data->set_imagefilename(document_name_);
image_data->set_page_number(page); image_data->set_page_number(page);
} }
image_data->set_language(lang_); image_data->set_language(lang_);
memory_used_ += image_data->MemoryUsed(); set_memory_used(memory_used() + image_data->MemoryUsed());
if (pages_offset_ != 0) {
pages_[page - pages_offset_] = image_data;
pages_[page] = NULL;
}
} }
} }
pages_.truncate(page - pages_offset_); if (page < loaded_pages) {
tprintf("Loaded %d/%d pages (%d-%d) of document %s\n", tprintf("Deserialize failed: %s read %d/%d pages\n",
pages_.size(), total_pages_, pages_offset_, document_name_.string(), page, loaded_pages);
pages_offset_ + pages_.size(), document_name_.string()); pages_.truncate(0);
} else {
tprintf("Loaded %d/%d pages (%d-%d) of document %s\n", pages_.size(),
loaded_pages, pages_offset_, pages_offset_ + pages_.size(),
document_name_.string());
}
set_total_pages(loaded_pages);
return !pages_.empty(); return !pages_.empty();
} }
// Adds the given page data to this document, counting up memory.
void DocumentData::AddPageToDocument(ImageData* page) {
pages_.push_back(page);
memory_used_ += page->MemoryUsed();
}
// A collection of DocumentData that knows roughly how much memory it is using. // A collection of DocumentData that knows roughly how much memory it is using.
DocumentCache::DocumentCache(inT64 max_memory) DocumentCache::DocumentCache(inT64 max_memory)
: total_pages_(0), memory_used_(0), max_memory_(max_memory) {} : num_pages_per_doc_(0), max_memory_(max_memory) {}
DocumentCache::~DocumentCache() {} DocumentCache::~DocumentCache() {}
// Adds all the documents in the list of filenames, counting memory. // Adds all the documents in the list of filenames, counting memory.
// The reader is used to read the files. // The reader is used to read the files.
bool DocumentCache::LoadDocuments(const GenericVector<STRING>& filenames, bool DocumentCache::LoadDocuments(const GenericVector<STRING>& filenames,
const char* lang, FileReader reader) { const char* lang,
inT64 fair_share_memory = max_memory_ / filenames.size(); CachingStrategy cache_strategy,
FileReader reader) {
cache_strategy_ = cache_strategy;
inT64 fair_share_memory = 0;
// In the round-robin case, each DocumentData handles restricting its content
// to its fair share of memory. In the sequential case, DocumentCache
// determines which DocumentDatas are held entirely in memory.
if (cache_strategy_ == CS_ROUND_ROBIN)
fair_share_memory = max_memory_ / filenames.size();
for (int arg = 0; arg < filenames.size(); ++arg) { for (int arg = 0; arg < filenames.size(); ++arg) {
STRING filename = filenames[arg]; STRING filename = filenames[arg];
DocumentData* document = new DocumentData(filename); DocumentData* document = new DocumentData(filename);
if (document->LoadDocument(filename.string(), lang, 0, document->SetDocument(filename.string(), lang, fair_share_memory, reader);
fair_share_memory, reader)) { AddToCache(document);
AddToCache(document); }
} else { if (!documents_.empty()) {
tprintf("Failed to load image %s!\n", filename.string()); // Try to get the first page now to verify the list of filenames.
delete document; if (GetPageBySerial(0) != NULL) return true;
} tprintf("Load of page 0 failed!\n");
} }
tprintf("Loaded %d pages, total %gMB\n", return false;
total_pages_, memory_used_ / 1048576.0);
return total_pages_ > 0;
} }
// Adds document to the cache, throwing out other documents if needed. // Adds document to the cache.
bool DocumentCache::AddToCache(DocumentData* data) { bool DocumentCache::AddToCache(DocumentData* data) {
inT64 new_memory = data->memory_used(); inT64 new_memory = data->memory_used();
memory_used_ += new_memory;
documents_.push_back(data); documents_.push_back(data);
total_pages_ += data->NumPages();
// Delete the first item in the array, and other pages of the same name
// while memory is full.
while (memory_used_ >= max_memory_ && max_memory_ > 0) {
tprintf("Memory used=%lld vs max=%lld, discarding doc of size %lld\n",
memory_used_ , max_memory_, documents_[0]->memory_used());
memory_used_ -= documents_[0]->memory_used();
total_pages_ -= documents_[0]->NumPages();
documents_.remove(0);
}
return true; return true;
} }
...@@ -488,11 +583,104 @@ DocumentData* DocumentCache::FindDocument(const STRING& document_name) const { ...@@ -488,11 +583,104 @@ DocumentData* DocumentCache::FindDocument(const STRING& document_name) const {
return NULL; return NULL;
} }
// Returns the total number of pages in an epoch. For CS_ROUND_ROBIN cache
// strategy, could take a long time.
int DocumentCache::TotalPages() {
if (cache_strategy_ == CS_SEQUENTIAL) {
// In sequential mode, we assume each doc has the same number of pages
// whether it is true or not.
if (num_pages_per_doc_ == 0) GetPageSequential(0);
return num_pages_per_doc_ * documents_.size();
}
int total_pages = 0;
int num_docs = documents_.size();
for (int d = 0; d < num_docs; ++d) {
// We have to load a page to make NumPages() valid.
documents_[d]->GetPage(0);
total_pages += documents_[d]->NumPages();
}
return total_pages;
}
// Returns a page by serial number, selecting them in a round-robin fashion // Returns a page by serial number, selecting them in a round-robin fashion
// from all the documents. // from all the documents. Highly disk-intensive, but doesn't need samples
const ImageData* DocumentCache::GetPageBySerial(int serial) { // to be shuffled between files to begin with.
int document_index = serial % documents_.size(); const ImageData* DocumentCache::GetPageRoundRobin(int serial) {
return documents_[document_index]->GetPage(serial / documents_.size()); int num_docs = documents_.size();
int doc_index = serial % num_docs;
const ImageData* doc = documents_[doc_index]->GetPage(serial / num_docs);
for (int offset = 1; offset <= kMaxReadAhead && offset < num_docs; ++offset) {
doc_index = (serial + offset) % num_docs;
int page = (serial + offset) / num_docs;
documents_[doc_index]->LoadPageInBackground(page);
}
return doc;
}
// Returns a page by serial number, selecting them in sequence from each file.
// Requires the samples to be shuffled between the files to give a random or
// uniform distribution of data. Less disk-intensive than GetPageRoundRobin.
const ImageData* DocumentCache::GetPageSequential(int serial) {
int num_docs = documents_.size();
ASSERT_HOST(num_docs > 0);
if (num_pages_per_doc_ == 0) {
// Use the pages in the first doc as the number of pages in each doc.
documents_[0]->GetPage(0);
num_pages_per_doc_ = documents_[0]->NumPages();
if (num_pages_per_doc_ == 0) {
tprintf("First document cannot be empty!!\n");
ASSERT_HOST(num_pages_per_doc_ > 0);
}
// Get rid of zero now if we don't need it.
if (serial / num_pages_per_doc_ % num_docs > 0) documents_[0]->UnCache();
}
int doc_index = serial / num_pages_per_doc_ % num_docs;
const ImageData* doc =
documents_[doc_index]->GetPage(serial % num_pages_per_doc_);
// Count up total memory. Background loading makes it more complicated to
// keep a running count.
inT64 total_memory = 0;
for (int d = 0; d < num_docs; ++d) {
total_memory += documents_[d]->memory_used();
}
if (total_memory >= max_memory_) {
// Find something to un-cache.
// If there are more than 3 in front, then serial is from the back reader
// of a pair of readers. If we un-cache from in-front-2 to 2-ahead, then
// we create a hole between them and then un-caching the backmost occupied
// will work for both.
int num_in_front = CountNeighbourDocs(doc_index, 1);
for (int offset = num_in_front - 2;
offset > 1 && total_memory >= max_memory_; --offset) {
int next_index = (doc_index + offset) % num_docs;
total_memory -= documents_[next_index]->UnCache();
}
// If that didn't work, the best solution is to un-cache from the back. If
// we take away the document that a 2nd reader is using, it will put it
// back and make a hole between.
int num_behind = CountNeighbourDocs(doc_index, -1);
for (int offset = num_behind; offset < 0 && total_memory >= max_memory_;
++offset) {
int next_index = (doc_index + offset + num_docs) % num_docs;
total_memory -= documents_[next_index]->UnCache();
}
}
int next_index = (doc_index + 1) % num_docs;
if (!documents_[next_index]->IsCached() && total_memory < max_memory_) {
documents_[next_index]->LoadPageInBackground(0);
}
return doc;
}
// Helper counts the number of adjacent cached neighbours of index looking in
// direction dir, ie index+dir, index+2*dir etc.
int DocumentCache::CountNeighbourDocs(int index, int dir) {
int num_docs = documents_.size();
for (int offset = dir; abs(offset) < num_docs; offset += dir) {
int offset_index = (index + offset + num_docs) % num_docs;
if (!documents_[offset_index]->IsCached()) return offset - dir;
}
return num_docs;
} }
} // namespace tesseract. } // namespace tesseract.
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "normalis.h" #include "normalis.h"
#include "rect.h" #include "rect.h"
#include "strngs.h" #include "strngs.h"
#include "svutil.h"
struct Pix; struct Pix;
...@@ -34,8 +35,22 @@ namespace tesseract { ...@@ -34,8 +35,22 @@ namespace tesseract {
const int kFeaturePadding = 2; const int kFeaturePadding = 2;
// Number of pixels to pad around text boxes. // Number of pixels to pad around text boxes.
const int kImagePadding = 4; const int kImagePadding = 4;
// Number of training images to combine into a mini-batch for training.
const int kNumPagesPerMiniBatch = 100; // Enum to determine the caching and data sequencing strategy.
enum CachingStrategy {
// Reads all of one file before moving on to the next. Requires samples to be
// shuffled across files. Uses the count of samples in the first file as
// the count in all the files to achieve high-speed random access. As a
// consequence, if subsequent files are smaller, they get entries used more
// than once, and if subsequent files are larger, some entries are not used.
// Best for larger data sets that don't fit in memory.
CS_SEQUENTIAL,
// Reads one sample from each file in rotation. Does not require shuffled
// samples, but is extremely disk-intensive. Samples in smaller files also
// get used more often than samples in larger files.
// Best for smaller data sets that mostly fit in memory.
CS_ROUND_ROBIN,
};
class WordFeature { class WordFeature {
public: public:
...@@ -103,6 +118,8 @@ class ImageData { ...@@ -103,6 +118,8 @@ class ImageData {
// Reads from the given file. Returns false in case of error. // Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed. // If swap is true, assumes a big/little-endian swap is needed.
bool DeSerialize(bool swap, TFile* fp); bool DeSerialize(bool swap, TFile* fp);
// As DeSerialize, but only seeks past the data - hence a static method.
static bool SkipDeSerialize(bool swap, tesseract::TFile* fp);
// Other accessors. // Other accessors.
const STRING& imagefilename() const { const STRING& imagefilename() const {
...@@ -145,11 +162,11 @@ class ImageData { ...@@ -145,11 +162,11 @@ class ImageData {
// Gets anything and everything with a non-NULL pointer, prescaled to a // Gets anything and everything with a non-NULL pointer, prescaled to a
// given target_height (if 0, then the original image height), and aligned. // given target_height (if 0, then the original image height), and aligned.
// Also returns (if not NULL) the width and height of the scaled image. // Also returns (if not NULL) the width and height of the scaled image.
// The return value is the scale factor that was applied to the image to // The return value is the scaled Pix, which must be pixDestroyed after use,
// achieve the target_height. // and scale_factor (if not NULL) is set to the scale factor that was applied
float PreScale(int target_height, Pix** pix, // to the image to achieve the target_height.
int* scaled_width, int* scaled_height, Pix* PreScale(int target_height, float* scale_factor, int* scaled_width,
GenericVector<TBOX>* boxes) const; int* scaled_height, GenericVector<TBOX>* boxes) const;
int MemoryUsed() const; int MemoryUsed() const;
...@@ -184,6 +201,8 @@ class ImageData { ...@@ -184,6 +201,8 @@ class ImageData {
// A collection of ImageData that knows roughly how much memory it is using. // A collection of ImageData that knows roughly how much memory it is using.
class DocumentData { class DocumentData {
friend void* ReCachePagesFunc(void* data);
public: public:
explicit DocumentData(const STRING& name); explicit DocumentData(const STRING& name);
~DocumentData(); ~DocumentData();
...@@ -192,6 +211,9 @@ class DocumentData { ...@@ -192,6 +211,9 @@ class DocumentData {
// is used to read the file. // is used to read the file.
bool LoadDocument(const char* filename, const char* lang, int start_page, bool LoadDocument(const char* filename, const char* lang, int start_page,
inT64 max_memory, FileReader reader); inT64 max_memory, FileReader reader);
// Sets up the document, without actually loading it.
void SetDocument(const char* filename, const char* lang, inT64 max_memory,
FileReader reader);
// Writes all the pages to the given filename. Returns false on error. // Writes all the pages to the given filename. Returns false on error.
bool SaveDocument(const char* filename, FileWriter writer); bool SaveDocument(const char* filename, FileWriter writer);
bool SaveToBuffer(GenericVector<char>* buffer); bool SaveToBuffer(GenericVector<char>* buffer);
...@@ -200,26 +222,62 @@ class DocumentData { ...@@ -200,26 +222,62 @@ class DocumentData {
void AddPageToDocument(ImageData* page); void AddPageToDocument(ImageData* page);
const STRING& document_name() const { const STRING& document_name() const {
SVAutoLock lock(&general_mutex_);
return document_name_; return document_name_;
} }
int NumPages() const { int NumPages() const {
SVAutoLock lock(&general_mutex_);
return total_pages_; return total_pages_;
} }
inT64 memory_used() const { inT64 memory_used() const {
SVAutoLock lock(&general_mutex_);
return memory_used_; return memory_used_;
} }
// If the given index is not currently loaded, loads it using a separate
// thread. Note: there are 4 cases:
// Document uncached: IsCached() returns false, total_pages_ < 0.
// Required page is available: IsPageAvailable returns true. In this case,
// total_pages_ > 0 and
// pages_offset_ <= index%total_pages_ <= pages_offset_+pages_.size()
// Pages are loaded, but the required one is not.
// The requested page is being loaded by LoadPageInBackground. In this case,
// index == pages_offset_. Once the loading starts, the pages lock is held
// until it completes, at which point IsPageAvailable will unblock and return
// true.
void LoadPageInBackground(int index);
// Returns a pointer to the page with the given index, modulo the total // Returns a pointer to the page with the given index, modulo the total
// number of pages, recaching if needed. // number of pages. Blocks until the background load is completed.
const ImageData* GetPage(int index); const ImageData* GetPage(int index);
// Returns true if the requested page is available, and provides a pointer,
// which may be NULL if the document is empty. May block, even though it
// doesn't guarantee to return true.
bool IsPageAvailable(int index, ImageData** page);
// Takes ownership of the given page index. The page is made NULL in *this. // Takes ownership of the given page index. The page is made NULL in *this.
ImageData* TakePage(int index) { ImageData* TakePage(int index) {
SVAutoLock lock(&pages_mutex_);
ImageData* page = pages_[index]; ImageData* page = pages_[index];
pages_[index] = NULL; pages_[index] = NULL;
return page; return page;
} }
// Returns true if the document is currently loaded or in the process of
// loading.
bool IsCached() const { return NumPages() >= 0; }
// Removes all pages from memory and frees the memory, but does not forget
// the document metadata. Returns the memory saved.
inT64 UnCache();
private: private:
// Loads as many pages can fit in max_memory_ starting at index pages_offset_. // Sets the value of total_pages_ behind a mutex.
void set_total_pages(int total) {
SVAutoLock lock(&general_mutex_);
total_pages_ = total;
}
void set_memory_used(inT64 memory_used) {
SVAutoLock lock(&general_mutex_);
memory_used_ = memory_used;
}
// Locks the pages_mutex_ and Loads as many pages can fit in max_memory_
// starting at index pages_offset_.
bool ReCachePages(); bool ReCachePages();
private: private:
...@@ -239,43 +297,77 @@ class DocumentData { ...@@ -239,43 +297,77 @@ class DocumentData {
inT64 max_memory_; inT64 max_memory_;
// Saved reader from LoadDocument to allow re-caching. // Saved reader from LoadDocument to allow re-caching.
FileReader reader_; FileReader reader_;
// Mutex that protects pages_ and pages_offset_ against multiple parallel
// loads, and provides a wait for page.
SVMutex pages_mutex_;
// Mutex that protects other data members that callers want to access without
// waiting for a load operation.
mutable SVMutex general_mutex_;
}; };
// A collection of DocumentData that knows roughly how much memory it is using. // A collection of DocumentData that knows roughly how much memory it is using.
// Note that while it supports background read-ahead, it assumes that a single
// thread is accessing documents, ie it is not safe for multiple threads to
// access different documents in parallel, as one may de-cache the other's
// content.
class DocumentCache { class DocumentCache {
public: public:
explicit DocumentCache(inT64 max_memory); explicit DocumentCache(inT64 max_memory);
~DocumentCache(); ~DocumentCache();
// Deletes all existing documents from the cache.
void Clear() {
documents_.clear();
num_pages_per_doc_ = 0;
}
// Adds all the documents in the list of filenames, counting memory. // Adds all the documents in the list of filenames, counting memory.
// The reader is used to read the files. // The reader is used to read the files.
bool LoadDocuments(const GenericVector<STRING>& filenames, const char* lang, bool LoadDocuments(const GenericVector<STRING>& filenames, const char* lang,
FileReader reader); CachingStrategy cache_strategy, FileReader reader);
// Adds document to the cache, throwing out other documents if needed. // Adds document to the cache.
bool AddToCache(DocumentData* data); bool AddToCache(DocumentData* data);
// Finds and returns a document by name. // Finds and returns a document by name.
DocumentData* FindDocument(const STRING& document_name) const; DocumentData* FindDocument(const STRING& document_name) const;
// Returns a page by serial number, selecting them in a round-robin fashion // Returns a page by serial number using the current cache_strategy_ to
// from all the documents. // determine the mapping from serial number to page.
const ImageData* GetPageBySerial(int serial); const ImageData* GetPageBySerial(int serial) {
if (cache_strategy_ == CS_SEQUENTIAL)
return GetPageSequential(serial);
else
return GetPageRoundRobin(serial);
}
const PointerVector<DocumentData>& documents() const { const PointerVector<DocumentData>& documents() const {
return documents_; return documents_;
} }
int total_pages() const { // Returns the total number of pages in an epoch. For CS_ROUND_ROBIN cache
return total_pages_; // strategy, could take a long time.
} int TotalPages();
private: private:
// Returns a page by serial number, selecting them in a round-robin fashion
// from all the documents. Highly disk-intensive, but doesn't need samples
// to be shuffled between files to begin with.
const ImageData* GetPageRoundRobin(int serial);
// Returns a page by serial number, selecting them in sequence from each file.
// Requires the samples to be shuffled between the files to give a random or
// uniform distribution of data. Less disk-intensive than GetPageRoundRobin.
const ImageData* GetPageSequential(int serial);
// Helper counts the number of adjacent cached neighbour documents_ of index
// looking in direction dir, ie index+dir, index+2*dir etc.
int CountNeighbourDocs(int index, int dir);
// A group of pages that corresponds in some loose way to a document. // A group of pages that corresponds in some loose way to a document.
PointerVector<DocumentData> documents_; PointerVector<DocumentData> documents_;
// Total of all pages. // Strategy to use for caching and serializing data samples.
int total_pages_; CachingStrategy cache_strategy_;
// Total of all memory used by the cache. // Number of pages in the first document, used as a divisor in
inT64 memory_used_; // GetPageSequential to determine the document index.
int num_pages_per_doc_;
// Max memory allowed in this cache. // Max memory allowed in this cache.
inT64 max_memory_; inT64 max_memory_;
}; };
......
/* -*-C-*- /* -*-C-*-
****************************************************************************** ******************************************************************************
* File: matrix.h (Formerly matrix.h)
* Description: Generic 2-d array/matrix and banded triangular matrix class.
* Author: Ray Smith
* TODO(rays) Separate from ratings matrix, which it also contains:
* *
* File: matrix.h (Formerly matrix.h) * Descrition: Ratings matrix class (specialization of banded matrix).
* Description: Ratings matrix code. (Used by associator) * Segmentation search matrix of lists of BLOB_CHOICE.
* Author: Mark Seaman, OCR Technology * Author: Mark Seaman, OCR Technology
* Created: Wed May 16 13:22:06 1990 * Created: Wed May 16 13:22:06 1990
* Modified: Tue Mar 19 16:00:20 1991 (Mark Seaman) marks@hpgrlt * Modified: Tue Mar 19 16:00:20 1991 (Mark Seaman) marks@hpgrlt
...@@ -25,9 +29,13 @@ ...@@ -25,9 +29,13 @@
#ifndef TESSERACT_CCSTRUCT_MATRIX_H__ #ifndef TESSERACT_CCSTRUCT_MATRIX_H__
#define TESSERACT_CCSTRUCT_MATRIX_H__ #define TESSERACT_CCSTRUCT_MATRIX_H__
#include <math.h>
#include "kdpair.h" #include "kdpair.h"
#include "points.h"
#include "serialis.h"
#include "unicharset.h" #include "unicharset.h"
class BLOB_CHOICE;
class BLOB_CHOICE_LIST; class BLOB_CHOICE_LIST;
#define NOT_CLASSIFIED reinterpret_cast<BLOB_CHOICE_LIST*>(0) #define NOT_CLASSIFIED reinterpret_cast<BLOB_CHOICE_LIST*>(0)
...@@ -44,34 +52,60 @@ class GENERIC_2D_ARRAY { ...@@ -44,34 +52,60 @@ class GENERIC_2D_ARRAY {
// either pass the memory in, or allocate after by calling Resize(). // either pass the memory in, or allocate after by calling Resize().
GENERIC_2D_ARRAY(int dim1, int dim2, const T& empty, T* array) GENERIC_2D_ARRAY(int dim1, int dim2, const T& empty, T* array)
: empty_(empty), dim1_(dim1), dim2_(dim2), array_(array) { : empty_(empty), dim1_(dim1), dim2_(dim2), array_(array) {
size_allocated_ = dim1 * dim2;
} }
// Original constructor for a full rectangular matrix DOES allocate memory // Original constructor for a full rectangular matrix DOES allocate memory
// and initialize it to empty. // and initialize it to empty.
GENERIC_2D_ARRAY(int dim1, int dim2, const T& empty) GENERIC_2D_ARRAY(int dim1, int dim2, const T& empty)
: empty_(empty), dim1_(dim1), dim2_(dim2) { : empty_(empty), dim1_(dim1), dim2_(dim2) {
array_ = new T[dim1_ * dim2_]; int new_size = dim1 * dim2;
for (int x = 0; x < dim1_; x++) array_ = new T[new_size];
for (int y = 0; y < dim2_; y++) size_allocated_ = new_size;
this->put(x, y, empty_); for (int i = 0; i < size_allocated_; ++i)
array_[i] = empty_;
}
  // Default constructor for array allocation. Use Resize to set the size.
  // All state starts zeroed: no backing store, 0x0 dimensions, and empty_
  // defaulted to T(0) until a later Resize supplies a real empty value.
  GENERIC_2D_ARRAY()
    : array_(NULL), empty_(static_cast<T>(0)), dim1_(0), dim2_(0),
      size_allocated_(0) {
  }
  // Copy constructor. Initializes to an empty state, then delegates to
  // operator= for the sizing and the (bitwise) element copy, so it is only
  // valid for element types that are safe to memcpy.
  GENERIC_2D_ARRAY(const GENERIC_2D_ARRAY<T>& src)
    : array_(NULL), empty_(static_cast<T>(0)), dim1_(0), dim2_(0),
      size_allocated_(0) {
    *this = src;
  }
virtual ~GENERIC_2D_ARRAY() { delete[] array_; } virtual ~GENERIC_2D_ARRAY() { delete[] array_; }
void operator=(const GENERIC_2D_ARRAY<T>& src) {
ResizeNoInit(src.dim1(), src.dim2());
memcpy(array_, src.array_, num_elements() * sizeof(array_[0]));
}
// Reallocate the array to the given size. Does not keep old data, but does
// not initialize the array either.
void ResizeNoInit(int size1, int size2) {
int new_size = size1 * size2;
if (new_size > size_allocated_) {
delete [] array_;
array_ = new T[new_size];
size_allocated_ = new_size;
}
dim1_ = size1;
dim2_ = size2;
}
// Reallocate the array to the given size. Does not keep old data. // Reallocate the array to the given size. Does not keep old data.
void Resize(int size1, int size2, const T& empty) { void Resize(int size1, int size2, const T& empty) {
empty_ = empty; empty_ = empty;
if (size1 != dim1_ || size2 != dim2_) { ResizeNoInit(size1, size2);
dim1_ = size1;
dim2_ = size2;
delete [] array_;
array_ = new T[dim1_ * dim2_];
}
Clear(); Clear();
} }
// Reallocate the array to the given size, keeping old data. // Reallocate the array to the given size, keeping old data.
void ResizeWithCopy(int size1, int size2) { void ResizeWithCopy(int size1, int size2) {
if (size1 != dim1_ || size2 != dim2_) { if (size1 != dim1_ || size2 != dim2_) {
T* new_array = new T[size1 * size2]; int new_size = size1 * size2;
T* new_array = new T[new_size];
for (int col = 0; col < size1; ++col) { for (int col = 0; col < size1; ++col) {
for (int row = 0; row < size2; ++row) { for (int row = 0; row < size2; ++row) {
int old_index = col * dim2() + row; int old_index = col * dim2() + row;
...@@ -87,6 +121,7 @@ class GENERIC_2D_ARRAY { ...@@ -87,6 +121,7 @@ class GENERIC_2D_ARRAY {
array_ = new_array; array_ = new_array;
dim1_ = size1; dim1_ = size1;
dim2_ = size2; dim2_ = size2;
size_allocated_ = new_size;
} }
} }
...@@ -106,9 +141,16 @@ class GENERIC_2D_ARRAY { ...@@ -106,9 +141,16 @@ class GENERIC_2D_ARRAY {
if (fwrite(array_, sizeof(*array_), size, fp) != size) return false; if (fwrite(array_, sizeof(*array_), size, fp) != size) return false;
return true; return true;
} }
  // Writes *this to the given TFile: the two dimensions (via SerializeSize),
  // the empty_ value, then the raw element data. Returns false in case of
  // error. Only works with bitwise-serializeable types, since elements are
  // written with a raw FWrite.
  bool Serialize(tesseract::TFile* fp) const {
    if (!SerializeSize(fp)) return false;
    if (fp->FWrite(&empty_, sizeof(empty_), 1) != 1) return false;
    int size = num_elements();
    if (fp->FWrite(array_, sizeof(*array_), size) != size) return false;
    return true;
  }
// Reads from the given file. Returns false in case of error. // Reads from the given file. Returns false in case of error.
// Only works with bitwise-serializeable typ // Only works with bitwise-serializeable types!
// If swap is true, assumes a big/little-endian swap is needed. // If swap is true, assumes a big/little-endian swap is needed.
bool DeSerialize(bool swap, FILE* fp) { bool DeSerialize(bool swap, FILE* fp) {
if (!DeSerializeSize(swap, fp)) return false; if (!DeSerializeSize(swap, fp)) return false;
...@@ -122,6 +164,18 @@ class GENERIC_2D_ARRAY { ...@@ -122,6 +164,18 @@ class GENERIC_2D_ARRAY {
} }
return true; return true;
} }
  // Reads *this from the given TFile, resizing to the stored dimensions.
  // Returns false in case of error. Only works with bitwise-serializeable
  // types. If swap is true, a big/little-endian byte swap (ReverseN) is
  // applied to empty_ and to every element after reading.
  bool DeSerialize(bool swap, tesseract::TFile* fp) {
    if (!DeSerializeSize(swap, fp)) return false;
    if (fp->FRead(&empty_, sizeof(empty_), 1) != 1) return false;
    if (swap) ReverseN(&empty_, sizeof(empty_));
    int size = num_elements();
    if (fp->FRead(array_, sizeof(*array_), size) != size) return false;
    if (swap) {
      for (int i = 0; i < size; ++i)
        ReverseN(&array_[i], sizeof(array_[i]));
    }
    return true;
  }
// Writes to the given file. Returns false in case of error. // Writes to the given file. Returns false in case of error.
// Assumes a T::Serialize(FILE*) const function. // Assumes a T::Serialize(FILE*) const function.
...@@ -163,11 +217,17 @@ class GENERIC_2D_ARRAY { ...@@ -163,11 +217,17 @@ class GENERIC_2D_ARRAY {
} }
// Put a list element into the matrix at a specific location. // Put a list element into the matrix at a specific location.
  // Puts an element into the matrix at the (x, y) location given by the
  // ICOORD point. Equivalent to put(pos.x(), pos.y(), thing).
  void put(ICOORD pos, const T& thing) {
    array_[this->index(pos.x(), pos.y())] = thing;
  }
void put(int column, int row, const T& thing) { void put(int column, int row, const T& thing) {
array_[this->index(column, row)] = thing; array_[this->index(column, row)] = thing;
} }
// Get the item at a specified location from the matrix. // Get the item at a specified location from the matrix.
  // Returns (by value) the element at the (x, y) location given by the
  // ICOORD point. Equivalent to get(pos.x(), pos.y()).
  T get(ICOORD pos) const {
    return array_[this->index(pos.x(), pos.y())];
  }
T get(int column, int row) const { T get(int column, int row) const {
return array_[this->index(column, row)]; return array_[this->index(column, row)];
} }
...@@ -187,6 +247,207 @@ class GENERIC_2D_ARRAY { ...@@ -187,6 +247,207 @@ class GENERIC_2D_ARRAY {
return &array_[this->index(column, 0)]; return &array_[this->index(column, 0)];
} }
// Adds addend to *this, element-by-element.
void operator+=(const GENERIC_2D_ARRAY<T>& addend) {
if (dim2_ == addend.dim2_) {
// Faster if equal size in the major dimension.
int size = MIN(num_elements(), addend.num_elements());
for (int i = 0; i < size; ++i) {
array_[i] += addend.array_[i];
}
} else {
for (int x = 0; x < dim1_; x++) {
for (int y = 0; y < dim2_; y++) {
(*this)(x, y) += addend(x, y);
}
}
}
}
// Subtracts minuend from *this, element-by-element.
void operator-=(const GENERIC_2D_ARRAY<T>& minuend) {
if (dim2_ == minuend.dim2_) {
// Faster if equal size in the major dimension.
int size = MIN(num_elements(), minuend.num_elements());
for (int i = 0; i < size; ++i) {
array_[i] -= minuend.array_[i];
}
} else {
for (int x = 0; x < dim1_; x++) {
for (int y = 0; y < dim2_; y++) {
(*this)(x, y) -= minuend(x, y);
}
}
}
}
// Adds addend to all elements.
void operator+=(const T& addend) {
int size = num_elements();
for (int i = 0; i < size; ++i) {
array_[i] += addend;
}
}
// Multiplies *this by factor, element-by-element.
void operator*=(const T& factor) {
int size = num_elements();
for (int i = 0; i < size; ++i) {
array_[i] *= factor;
}
}
// Clips *this to the given range.
void Clip(const T& rangemin, const T& rangemax) {
int size = num_elements();
for (int i = 0; i < size; ++i) {
array_[i] = ClipToRange(array_[i], rangemin, rangemax);
}
}
// Returns true if all elements of *this are within the given range.
// Only uses operator<
bool WithinBounds(const T& rangemin, const T& rangemax) const {
int size = num_elements();
for (int i = 0; i < size; ++i) {
const T& value = array_[i];
if (value < rangemin || rangemax < value)
return false;
}
return true;
}
  // Normalizes the whole array in place to zero mean and, where possible,
  // unit standard deviation. Returns the standard deviation of the
  // mean-subtracted data. If the sd is 0 (all elements equal), the elements
  // are left as the mean-subtracted values (all zero) and 0.0 is returned.
  // Returns 0.0 without touching anything when the array is empty.
  double Normalize() {
    int size = num_elements();
    if (size <= 0) return 0.0;
    // Compute the mean.
    double mean = 0.0;
    for (int i = 0; i < size; ++i) {
      mean += array_[i];
    }
    mean /= size;
    // Subtract the mean and compute the standard deviation.
    // (Assigning the double back through T narrows for integral T.)
    double sd = 0.0;
    for (int i = 0; i < size; ++i) {
      double normed = array_[i] - mean;
      array_[i] = normed;
      sd += normed * normed;
    }
    sd = sqrt(sd / size);
    if (sd > 0.0) {
      // Divide by the sd.
      for (int i = 0; i < size; ++i) {
        array_[i] /= sd;
      }
    }
    return sd;
  }
// Returns the maximum value of the array.
T Max() const {
int size = num_elements();
if (size <= 0) return empty_;
// Compute the max.
T max_value = array_[0];
for (int i = 1; i < size; ++i) {
const T& value = array_[i];
if (value > max_value) max_value = value;
}
return max_value;
}
  // Returns the maximum absolute value of the array, or empty_ if the array
  // has no elements.
  // NOTE(review): each element is passed through fabs (double), which is
  // exact for float/double T but lossy for large 64-bit integer T — confirm
  // the instantiations used.
  T MaxAbs() const {
    int size = num_elements();
    if (size <= 0) return empty_;
    // Compute the max.
    T max_abs = static_cast<T>(0);
    for (int i = 0; i < size; ++i) {
      T value = static_cast<T>(fabs(array_[i]));
      if (value > max_abs) max_abs = value;
    }
    return max_abs;
  }
  // Accumulates the element-wise sums of squares of src into *this.
  // src is assumed to have at least num_elements() entries — no size check
  // is performed before reading src.array_.
  void SumSquares(const GENERIC_2D_ARRAY<T>& src) {
    int size = num_elements();
    for (int i = 0; i < size; ++i) {
      array_[i] += src.array_[i] * src.array_[i];
    }
  }
  // Scales each element using the ada-grad algorithm, ie array_[i] by
  // sqrt(num_samples/max(1,sqsum[i])).
  // sqsum holds accumulated sums of squares (see SumSquares) and is assumed
  // to have at least num_elements() entries — no size check is performed.
  void AdaGradScaling(const GENERIC_2D_ARRAY<T>& sqsum, int num_samples) {
    int size = num_elements();
    for (int i = 0; i < size; ++i) {
      array_[i] *= sqrt(num_samples / MAX(1.0, sqsum.array_[i]));
    }
  }
  // Asserts (via ASSERT_HOST, which aborts on failure) that every element of
  // the array is finite, ie neither NaN nor +/-infinity. Debugging aid for
  // catching numeric blow-ups early.
  void AssertFinite() const {
    int size = num_elements();
    for (int i = 0; i < size; ++i) {
      ASSERT_HOST(isfinite(array_[i]));
    }
  }
  // REGARDLESS OF THE CURRENT DIMENSIONS, treats the data as a
  // num_dims-dimensional array/tensor with dimensions given by dims, (ordered
  // from most significant to least significant, the same as standard C arrays)
  // and moves src_dim to dest_dim, with the initial dest_dim and any dimensions
  // in between shifted towards the hole left by src_dim. Example:
  // Current data content: array_=[0, 1, 2, ....119]
  // perhaps *this may be of dim[40, 3], with values [[0, 1, 2][3, 4, 5]...
  // but the current dimensions are irrelevant.
  // num_dims = 4, dims=[5, 4, 3, 2]
  // src_dim=3, dest_dim=1
  // tensor=[[[[0, 1][2, 3][4, 5]]
  //          [[6, 7][8, 9][10, 11]]
  //          [[12, 13][14, 15][16, 17]]
  //          [[18, 19][20, 21][22, 23]]]
  //         [[[24, 25]...
  // output dims =[5, 2, 4, 3]
  // output tensor=[[[[0, 2, 4][6, 8, 10][12, 14, 16][18, 20, 22]]
  //                 [[1, 3, 5][7, 9, 11][13, 15, 17][19, 21, 23]]]
  //                [[[24, 26, 28]...
  // which is stored in the array_ as:
  // [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 1, 3, 5, 7, 9, 11, 13...]
  // NOTE: the 2 stored matrix dimensions are simply copied from *this. To
  // change the dimensions after the transpose, use ResizeNoInit.
  // Higher dimensions above 2 are strictly the responsibility of the caller.
  void RotatingTranspose(const int* dims, int num_dims, int src_dim,
                         int dest_dim, GENERIC_2D_ARRAY<T>* result) const {
    int max_d = MAX(src_dim, dest_dim);
    int min_d = MIN(src_dim, dest_dim);
    // In a tensor of shape [d0, d1... min_d, ... max_d, ... dn-2, dn-1], the
    // ends outside of min_d and max_d are unaffected, with [max_d +1, dn-1]
    // being contiguous blocks of data that will move together, and
    // [d0, min_d -1] being replicas of the transpose operation.
    // num_replicas represents the large dimensions unchanged by the operation.
    // move_size represents the small dimensions unchanged by the operation.
    // src_step represents the stride in the src between each adjacent group
    // in the destination.
    int num_replicas = 1, move_size = 1, src_step = 1;
    for (int d = 0; d < min_d; ++d) num_replicas *= dims[d];
    for (int d = max_d + 1; d < num_dims; ++d) move_size *= dims[d];
    for (int d = src_dim + 1; d < num_dims; ++d) src_step *= dims[d];
    if (src_dim > dest_dim) src_step *= dims[src_dim];
    // wrap_size is the size of a single replica, being the amount that is
    // handled num_replicas times.
    int wrap_size = move_size;
    for (int d = min_d; d <= max_d; ++d) wrap_size *= dims[d];
    // The result keeps this matrix's stored 2-d shape; only the element
    // order in the flat array changes (see NOTE above).
    result->ResizeNoInit(dim1_, dim2_);
    result->empty_ = empty_;
    const T* src = array_;
    T* dest = result->array_;
    // For each replica, gather groups of move_size elements that lie
    // src_step apart in the source and write them contiguously to dest.
    for (int replica = 0; replica < num_replicas; ++replica) {
      for (int start = 0; start < src_step; start += move_size) {
        for (int pos = start; pos < wrap_size; pos += src_step) {
          memcpy(dest, src + pos, sizeof(*dest) * move_size);
          dest += move_size;
        }
      }
      src += wrap_size;
    }
  }
// Delete objects pointed to by array_[i]. // Delete objects pointed to by array_[i].
void delete_matrix_pointers() { void delete_matrix_pointers() {
int size = num_elements(); int size = num_elements();
...@@ -206,6 +467,13 @@ class GENERIC_2D_ARRAY { ...@@ -206,6 +467,13 @@ class GENERIC_2D_ARRAY {
if (fwrite(&size, sizeof(size), 1, fp) != 1) return false; if (fwrite(&size, sizeof(size), 1, fp) != 1) return false;
return true; return true;
} }
bool SerializeSize(tesseract::TFile* fp) const {
inT32 size = dim1_;
if (fp->FWrite(&size, sizeof(size), 1) != 1) return false;
size = dim2_;
if (fp->FWrite(&size, sizeof(size), 1) != 1) return false;
return true;
}
// Factored helper to deserialize the size. // Factored helper to deserialize the size.
// If swap is true, assumes a big/little-endian swap is needed. // If swap is true, assumes a big/little-endian swap is needed.
bool DeSerializeSize(bool swap, FILE* fp) { bool DeSerializeSize(bool swap, FILE* fp) {
...@@ -219,11 +487,26 @@ class GENERIC_2D_ARRAY { ...@@ -219,11 +487,26 @@ class GENERIC_2D_ARRAY {
Resize(size1, size2, empty_); Resize(size1, size2, empty_);
return true; return true;
} }
bool DeSerializeSize(bool swap, tesseract::TFile* fp) {
inT32 size1, size2;
if (fp->FRead(&size1, sizeof(size1), 1) != 1) return false;
if (fp->FRead(&size2, sizeof(size2), 1) != 1) return false;
if (swap) {
ReverseN(&size1, sizeof(size1));
ReverseN(&size2, sizeof(size2));
}
Resize(size1, size2, empty_);
return true;
}
T* array_; T* array_;
T empty_; // The unused cell. T empty_; // The unused cell.
int dim1_; // Size of the 1st dimension in indexing functions. int dim1_; // Size of the 1st dimension in indexing functions.
int dim2_; // Size of the 2nd dimension in indexing functions. int dim2_; // Size of the 2nd dimension in indexing functions.
// The total size to which the array can be expanded before a realloc is
// needed. If Resize is used, memory is retained so it can be re-expanded
// without a further alloc, and this stores the allocated size.
int size_allocated_;
}; };
// A generic class to store a banded triangular matrix with entries of type T. // A generic class to store a banded triangular matrix with entries of type T.
......
...@@ -304,6 +304,7 @@ bool WERD_RES::SetupForRecognition(const UNICHARSET& unicharset_in, ...@@ -304,6 +304,7 @@ bool WERD_RES::SetupForRecognition(const UNICHARSET& unicharset_in,
tesseract = tess; tesseract = tess;
POLY_BLOCK* pb = block != NULL ? block->poly_block() : NULL; POLY_BLOCK* pb = block != NULL ? block->poly_block() : NULL;
if ((norm_mode_hint != tesseract::OEM_CUBE_ONLY && if ((norm_mode_hint != tesseract::OEM_CUBE_ONLY &&
norm_mode_hint != tesseract::OEM_LSTM_ONLY &&
word->cblob_list()->empty()) || (pb != NULL && !pb->IsText())) { word->cblob_list()->empty()) || (pb != NULL && !pb->IsText())) {
// Empty words occur when all the blobs have been moved to the rej_blobs // Empty words occur when all the blobs have been moved to the rej_blobs
// list, which seems to occur frequently in junk. // list, which seems to occur frequently in junk.
...@@ -882,17 +883,17 @@ void WERD_RES::FakeClassifyWord(int blob_count, BLOB_CHOICE** choices) { ...@@ -882,17 +883,17 @@ void WERD_RES::FakeClassifyWord(int blob_count, BLOB_CHOICE** choices) {
choice_it.add_after_then_move(choices[c]); choice_it.add_after_then_move(choices[c]);
ratings->put(c, c, choice_list); ratings->put(c, c, choice_list);
} }
FakeWordFromRatings(); FakeWordFromRatings(TOP_CHOICE_PERM);
reject_map.initialise(blob_count); reject_map.initialise(blob_count);
done = true; done = true;
} }
// Creates a WERD_CHOICE for the word using the top choices from the leading // Creates a WERD_CHOICE for the word using the top choices from the leading
// diagonal of the ratings matrix. // diagonal of the ratings matrix.
void WERD_RES::FakeWordFromRatings() { void WERD_RES::FakeWordFromRatings(PermuterType permuter) {
int num_blobs = ratings->dimension(); int num_blobs = ratings->dimension();
WERD_CHOICE* word_choice = new WERD_CHOICE(uch_set, num_blobs); WERD_CHOICE* word_choice = new WERD_CHOICE(uch_set, num_blobs);
word_choice->set_permuter(TOP_CHOICE_PERM); word_choice->set_permuter(permuter);
for (int b = 0; b < num_blobs; ++b) { for (int b = 0; b < num_blobs; ++b) {
UNICHAR_ID unichar_id = UNICHAR_SPACE; UNICHAR_ID unichar_id = UNICHAR_SPACE;
float rating = MAX_INT32; float rating = MAX_INT32;
...@@ -1105,6 +1106,7 @@ void WERD_RES::InitNonPointers() { ...@@ -1105,6 +1106,7 @@ void WERD_RES::InitNonPointers() {
x_height = 0.0; x_height = 0.0;
caps_height = 0.0; caps_height = 0.0;
baseline_shift = 0.0f; baseline_shift = 0.0f;
space_certainty = 0.0f;
guessed_x_ht = TRUE; guessed_x_ht = TRUE;
guessed_caps_ht = TRUE; guessed_caps_ht = TRUE;
combination = FALSE; combination = FALSE;
......
...@@ -295,6 +295,9 @@ class WERD_RES : public ELIST_LINK { ...@@ -295,6 +295,9 @@ class WERD_RES : public ELIST_LINK {
float x_height; // post match estimate float x_height; // post match estimate
float caps_height; // post match estimate float caps_height; // post match estimate
float baseline_shift; // post match estimate. float baseline_shift; // post match estimate.
// Certainty score for the spaces either side of this word (LSTM mode).
// MIN this value with the actual word certainty.
float space_certainty;
/* /*
To deal with fuzzy spaces we need to be able to combine "words" to form To deal with fuzzy spaces we need to be able to combine "words" to form
...@@ -590,7 +593,7 @@ class WERD_RES : public ELIST_LINK { ...@@ -590,7 +593,7 @@ class WERD_RES : public ELIST_LINK {
// Creates a WERD_CHOICE for the word using the top choices from the leading // Creates a WERD_CHOICE for the word using the top choices from the leading
// diagonal of the ratings matrix. // diagonal of the ratings matrix.
void FakeWordFromRatings(); void FakeWordFromRatings(PermuterType permuter);
// Copies the best_choice strings to the correct_text for adaption/training. // Copies the best_choice strings to the correct_text for adaption/training.
void BestChoiceToCorrectText(); void BestChoiceToCorrectText();
......
...@@ -257,13 +257,21 @@ enum OcrEngineMode { ...@@ -257,13 +257,21 @@ enum OcrEngineMode {
OEM_TESSERACT_ONLY, // Run Tesseract only - fastest OEM_TESSERACT_ONLY, // Run Tesseract only - fastest
OEM_CUBE_ONLY, // Run Cube only - better accuracy, but slower OEM_CUBE_ONLY, // Run Cube only - better accuracy, but slower
OEM_TESSERACT_CUBE_COMBINED, // Run both and combine results - best accuracy OEM_TESSERACT_CUBE_COMBINED, // Run both and combine results - best accuracy
OEM_DEFAULT // Specify this mode when calling init_*(), OEM_DEFAULT, // Specify this mode when calling init_*(),
// to indicate that any of the above modes // to indicate that any of the above modes
// should be automatically inferred from the // should be automatically inferred from the
// variables in the language-specific config, // variables in the language-specific config,
// command-line configs, or if not specified // command-line configs, or if not specified
// in any of the above should be set to the // in any of the above should be set to the
// default OEM_TESSERACT_ONLY. // default OEM_TESSERACT_ONLY.
// OEM_LSTM_ONLY will fall back (with a warning) to OEM_TESSERACT_ONLY where
// there is no network model available. This allows use of a mix of languages,
// some of which contain a network model, and some of which do not. Since the
// tesseract model is required for the LSTM to fall back to for "difficult"
// words anyway, this seems like a reasonable approach, but leaves the danger
// of not noticing that it is using the wrong engine if the warning is
// ignored.
OEM_LSTM_ONLY, // Run just the LSTM line recognizer.
}; };
} // namespace tesseract. } // namespace tesseract.
......
...@@ -14,7 +14,7 @@ endif ...@@ -14,7 +14,7 @@ endif
include_HEADERS = \ include_HEADERS = \
basedir.h errcode.h fileerr.h genericvector.h helpers.h host.h memry.h \ basedir.h errcode.h fileerr.h genericvector.h helpers.h host.h memry.h \
ndminx.h params.h ocrclass.h platform.h serialis.h strngs.h \ ndminx.h params.h ocrclass.h platform.h serialis.h strngs.h \
tesscallback.h unichar.h unicharmap.h unicharset.h tesscallback.h unichar.h unicharcompress.h unicharmap.h unicharset.h
noinst_HEADERS = \ noinst_HEADERS = \
ambigs.h bits16.h bitvector.h ccutil.h clst.h doubleptr.h elst2.h \ ambigs.h bits16.h bitvector.h ccutil.h clst.h doubleptr.h elst2.h \
...@@ -38,7 +38,7 @@ libtesseract_ccutil_la_SOURCES = \ ...@@ -38,7 +38,7 @@ libtesseract_ccutil_la_SOURCES = \
mainblk.cpp memry.cpp \ mainblk.cpp memry.cpp \
serialis.cpp strngs.cpp scanutils.cpp \ serialis.cpp strngs.cpp scanutils.cpp \
tessdatamanager.cpp tprintf.cpp \ tessdatamanager.cpp tprintf.cpp \
unichar.cpp unicharmap.cpp unicharset.cpp unicodes.cpp \ unichar.cpp unicharcompress.cpp unicharmap.cpp unicharset.cpp unicodes.cpp \
params.cpp universalambigs.cpp params.cpp universalambigs.cpp
if T_WIN if T_WIN
......
...@@ -108,6 +108,8 @@ class GenericHeap { ...@@ -108,6 +108,8 @@ class GenericHeap {
const Pair& PeekTop() const { const Pair& PeekTop() const {
return heap_[0]; return heap_[0];
} }
  // Get the value of the worst (largest, defined by operator< ) element.
  // NOTE(review): like PeekTop, this must not be called on an empty heap:
  // IndexOfWorst() returns -1 in that case and heap_[-1] is out of bounds.
  const Pair& PeekWorst() const { return heap_[IndexOfWorst()]; }
// Removes the top element of the heap. If entry is not NULL, the element // Removes the top element of the heap. If entry is not NULL, the element
// is copied into *entry, otherwise it is discarded. // is copied into *entry, otherwise it is discarded.
...@@ -136,22 +138,12 @@ class GenericHeap { ...@@ -136,22 +138,12 @@ class GenericHeap {
// not NULL, the element is copied into *entry, otherwise it is discarded. // not NULL, the element is copied into *entry, otherwise it is discarded.
// Time = O(n). Returns false if the heap was already empty. // Time = O(n). Returns false if the heap was already empty.
bool PopWorst(Pair* entry) { bool PopWorst(Pair* entry) {
int heap_size = heap_.size(); int worst_index = IndexOfWorst();
if (heap_size == 0) return false; // It cannot be empty! if (worst_index < 0) return false; // It cannot be empty!
// Find the maximum element. Its index is guaranteed to be greater than
// the index of the parent of the last element, since by the heap invariant
// the parent must be less than or equal to the children.
int worst_index = heap_size - 1;
int end_parent = ParentNode(worst_index);
for (int i = worst_index - 1; i > end_parent; --i) {
if (heap_[worst_index] < heap_[i])
worst_index = i;
}
// Extract the worst element from the heap, leaving a hole at worst_index. // Extract the worst element from the heap, leaving a hole at worst_index.
if (entry != NULL) if (entry != NULL)
*entry = heap_[worst_index]; *entry = heap_[worst_index];
--heap_size; int heap_size = heap_.size() - 1;
if (heap_size > 0) { if (heap_size > 0) {
// Sift the hole upwards to match the last element of the heap_ // Sift the hole upwards to match the last element of the heap_
Pair hole_pair = heap_[heap_size]; Pair hole_pair = heap_[heap_size];
...@@ -162,6 +154,22 @@ class GenericHeap { ...@@ -162,6 +154,22 @@ class GenericHeap {
return true; return true;
} }
// Returns the index of the worst element. Time = O(n/2).
int IndexOfWorst() const {
int heap_size = heap_.size();
if (heap_size == 0) return -1; // It cannot be empty!
// Find the maximum element. Its index is guaranteed to be greater than
// the index of the parent of the last element, since by the heap invariant
// the parent must be less than or equal to the children.
int worst_index = heap_size - 1;
int end_parent = ParentNode(worst_index);
for (int i = worst_index - 1; i > end_parent; --i) {
if (heap_[worst_index] < heap_[i]) worst_index = i;
}
return worst_index;
}
// The pointed-to Pair has changed its key value, so the location of pair // The pointed-to Pair has changed its key value, so the location of pair
// is reshuffled to maintain the heap invariant. // is reshuffled to maintain the heap invariant.
// Must be a valid pointer to an element of the heap_! // Must be a valid pointer to an element of the heap_!
......
...@@ -174,6 +174,8 @@ class GenericVector { ...@@ -174,6 +174,8 @@ class GenericVector {
// If swap is true, assumes a big/little-endian swap is needed. // If swap is true, assumes a big/little-endian swap is needed.
bool DeSerialize(bool swap, FILE* fp); bool DeSerialize(bool swap, FILE* fp);
bool DeSerialize(bool swap, tesseract::TFile* fp); bool DeSerialize(bool swap, tesseract::TFile* fp);
// Skips the deserialization of the vector.
static bool SkipDeSerialize(bool swap, tesseract::TFile* fp);
// Writes a vector of classes to the given file. Assumes the existence of // Writes a vector of classes to the given file. Assumes the existence of
// bool T::Serialize(FILE* fp) const that returns false in case of error. // bool T::Serialize(FILE* fp) const that returns false in case of error.
// Returns false in case of error. // Returns false in case of error.
...@@ -186,6 +188,8 @@ class GenericVector { ...@@ -186,6 +188,8 @@ class GenericVector {
// If swap is true, assumes a big/little-endian swap is needed. // If swap is true, assumes a big/little-endian swap is needed.
bool DeSerializeClasses(bool swap, FILE* fp); bool DeSerializeClasses(bool swap, FILE* fp);
bool DeSerializeClasses(bool swap, tesseract::TFile* fp); bool DeSerializeClasses(bool swap, tesseract::TFile* fp);
// Calls SkipDeSerialize on the elements of the vector.
static bool SkipDeSerializeClasses(bool swap, tesseract::TFile* fp);
// Allocates a new array of double the current_size, copies over the // Allocates a new array of double the current_size, copies over the
// information from data to the new location, deletes data and returns // information from data to the new location, deletes data and returns
...@@ -238,14 +242,13 @@ class GenericVector { ...@@ -238,14 +242,13 @@ class GenericVector {
int binary_search(const T& target) const { int binary_search(const T& target) const {
int bottom = 0; int bottom = 0;
int top = size_used_; int top = size_used_;
do { while (top - bottom > 1) {
int middle = (bottom + top) / 2; int middle = (bottom + top) / 2;
if (data_[middle] > target) if (data_[middle] > target)
top = middle; top = middle;
else else
bottom = middle; bottom = middle;
} }
while (top - bottom > 1);
return bottom; return bottom;
} }
...@@ -361,7 +364,7 @@ inline bool LoadDataFromFile(const STRING& filename, ...@@ -361,7 +364,7 @@ inline bool LoadDataFromFile(const STRING& filename,
size_t size = ftell(fp); size_t size = ftell(fp);
fseek(fp, 0, SEEK_SET); fseek(fp, 0, SEEK_SET);
// Pad with a 0, just in case we treat the result as a string. // Pad with a 0, just in case we treat the result as a string.
data->init_to_size((int)size + 1, 0); data->init_to_size(static_cast<int>(size) + 1, 0);
bool result = fread(&(*data)[0], 1, size, fp) == size; bool result = fread(&(*data)[0], 1, size, fp) == size;
fclose(fp); fclose(fp);
return result; return result;
...@@ -556,34 +559,54 @@ class PointerVector : public GenericVector<T*> { ...@@ -556,34 +559,54 @@ class PointerVector : public GenericVector<T*> {
} }
bool DeSerialize(bool swap, TFile* fp) { bool DeSerialize(bool swap, TFile* fp) {
inT32 reserved; inT32 reserved;
if (fp->FRead(&reserved, sizeof(reserved), 1) != 1) return false; if (!DeSerializeSize(swap, fp, &reserved)) return false;
if (swap) Reverse32(&reserved);
GenericVector<T*>::reserve(reserved); GenericVector<T*>::reserve(reserved);
truncate(0); truncate(0);
for (int i = 0; i < reserved; ++i) { for (int i = 0; i < reserved; ++i) {
inT8 non_null; if (!DeSerializeElement(swap, fp)) return false;
if (fp->FRead(&non_null, sizeof(non_null), 1) != 1) return false; }
T* item = NULL; return true;
if (non_null) { }
item = new T; // Enables deserialization of a selection of elements. Note that in order to
if (!item->DeSerialize(swap, fp)) { // retain the integrity of the stream, the caller must call some combination
delete item; // of DeSerializeElement and DeSerializeSkip of the exact number returned in
return false; // *size, assuming a true return.
} static bool DeSerializeSize(bool swap, TFile* fp, inT32* size) {
this->push_back(item); if (fp->FRead(size, sizeof(*size), 1) != 1) return false;
} else { if (swap) Reverse32(size);
// Null elements should keep their place in the vector. return true;
this->push_back(NULL); }
// Reads and appends to the vector the next element of the serialization.
bool DeSerializeElement(bool swap, TFile* fp) {
inT8 non_null;
if (fp->FRead(&non_null, sizeof(non_null), 1) != 1) return false;
T* item = NULL;
if (non_null) {
item = new T;
if (!item->DeSerialize(swap, fp)) {
delete item;
return false;
} }
this->push_back(item);
} else {
// Null elements should keep their place in the vector.
this->push_back(NULL);
}
return true;
}
// Skips the next element of the serialization.
static bool DeSerializeSkip(bool swap, TFile* fp) {
inT8 non_null;
if (fp->FRead(&non_null, sizeof(non_null), 1) != 1) return false;
if (non_null) {
if (!T::SkipDeSerialize(swap, fp)) return false;
} }
return true; return true;
} }
// Sorts the items pointed to by the members of this vector using // Sorts the items pointed to by the members of this vector using
// t::operator<(). // t::operator<().
void sort() { void sort() { this->GenericVector<T*>::sort(&sort_ptr_cmp<T>); }
sort(&sort_ptr_cmp<T>);
}
}; };
} // namespace tesseract } // namespace tesseract
...@@ -926,6 +949,13 @@ bool GenericVector<T>::DeSerialize(bool swap, tesseract::TFile* fp) { ...@@ -926,6 +949,13 @@ bool GenericVector<T>::DeSerialize(bool swap, tesseract::TFile* fp) {
} }
return true; return true;
} }
template <typename T>
bool GenericVector<T>::SkipDeSerialize(bool swap, tesseract::TFile* fp) {
inT32 reserved;
if (fp->FRead(&reserved, sizeof(reserved), 1) != 1) return false;
if (swap) Reverse32(&reserved);
return fp->FRead(NULL, sizeof(T), reserved) == reserved;
}
// Writes a vector of classes to the given file. Assumes the existence of // Writes a vector of classes to the given file. Assumes the existence of
// bool T::Serialize(FILE* fp) const that returns false in case of error. // bool T::Serialize(FILE* fp) const that returns false in case of error.
...@@ -976,6 +1006,16 @@ bool GenericVector<T>::DeSerializeClasses(bool swap, tesseract::TFile* fp) { ...@@ -976,6 +1006,16 @@ bool GenericVector<T>::DeSerializeClasses(bool swap, tesseract::TFile* fp) {
} }
return true; return true;
} }
template <typename T>
bool GenericVector<T>::SkipDeSerializeClasses(bool swap, tesseract::TFile* fp) {
uinT32 reserved;
if (fp->FRead(&reserved, sizeof(reserved), 1) != 1) return false;
if (swap) Reverse32(&reserved);
for (int i = 0; i < reserved; ++i) {
if (!T::SkipDeSerialize(swap, fp)) return false;
}
return true;
}
// This method clear the current object, then, does a shallow copy of // This method clear the current object, then, does a shallow copy of
// its argument, and finally invalidates its argument. // its argument, and finally invalidates its argument.
......
...@@ -95,7 +95,7 @@ int TFile::FRead(void* buffer, int size, int count) { ...@@ -95,7 +95,7 @@ int TFile::FRead(void* buffer, int size, int count) {
char* char_buffer = reinterpret_cast<char*>(buffer); char* char_buffer = reinterpret_cast<char*>(buffer);
if (data_->size() - offset_ < required_size) if (data_->size() - offset_ < required_size)
required_size = data_->size() - offset_; required_size = data_->size() - offset_;
if (required_size > 0) if (required_size > 0 && char_buffer != NULL)
memcpy(char_buffer, &(*data_)[offset_], required_size); memcpy(char_buffer, &(*data_)[offset_], required_size);
offset_ += required_size; offset_ += required_size;
return required_size / size; return required_size / size;
......
...@@ -181,6 +181,14 @@ bool STRING::DeSerialize(bool swap, TFile* fp) { ...@@ -181,6 +181,14 @@ bool STRING::DeSerialize(bool swap, TFile* fp) {
return true; return true;
} }
// As DeSerialize, but only seeks past the data - hence a static method.
bool STRING::SkipDeSerialize(bool swap, tesseract::TFile* fp) {
inT32 len;
if (fp->FRead(&len, sizeof(len), 1) != 1) return false;
if (swap) ReverseN(&len, sizeof(len));
return fp->FRead(NULL, 1, len) == len;
}
BOOL8 STRING::contains(const char c) const { BOOL8 STRING::contains(const char c) const {
return (c != '\0') && (strchr (GetCStr(), c) != NULL); return (c != '\0') && (strchr (GetCStr(), c) != NULL);
} }
......
...@@ -60,6 +60,8 @@ class TESS_API STRING ...@@ -60,6 +60,8 @@ class TESS_API STRING
// Reads from the given file. Returns false in case of error. // Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed. // If swap is true, assumes a big/little-endian swap is needed.
bool DeSerialize(bool swap, tesseract::TFile* fp); bool DeSerialize(bool swap, tesseract::TFile* fp);
// As DeSerialize, but only seeks past the data - hence a static method.
static bool SkipDeSerialize(bool swap, tesseract::TFile* fp);
BOOL8 contains(const char c) const; BOOL8 contains(const char c) const;
inT32 length() const; inT32 length() const;
......
...@@ -47,6 +47,10 @@ static const char kShapeTableFileSuffix[] = "shapetable"; ...@@ -47,6 +47,10 @@ static const char kShapeTableFileSuffix[] = "shapetable";
static const char kBigramDawgFileSuffix[] = "bigram-dawg"; static const char kBigramDawgFileSuffix[] = "bigram-dawg";
static const char kUnambigDawgFileSuffix[] = "unambig-dawg"; static const char kUnambigDawgFileSuffix[] = "unambig-dawg";
static const char kParamsModelFileSuffix[] = "params-model"; static const char kParamsModelFileSuffix[] = "params-model";
static const char kLSTMModelFileSuffix[] = "lstm";
static const char kLSTMPuncDawgFileSuffix[] = "lstm-punc-dawg";
static const char kLSTMSystemDawgFileSuffix[] = "lstm-word-dawg";
static const char kLSTMNumberDawgFileSuffix[] = "lstm-number-dawg";
namespace tesseract { namespace tesseract {
...@@ -68,6 +72,10 @@ enum TessdataType { ...@@ -68,6 +72,10 @@ enum TessdataType {
TESSDATA_BIGRAM_DAWG, // 14 TESSDATA_BIGRAM_DAWG, // 14
TESSDATA_UNAMBIG_DAWG, // 15 TESSDATA_UNAMBIG_DAWG, // 15
TESSDATA_PARAMS_MODEL, // 16 TESSDATA_PARAMS_MODEL, // 16
TESSDATA_LSTM, // 17
TESSDATA_LSTM_PUNC_DAWG, // 18
TESSDATA_LSTM_SYSTEM_DAWG, // 19
TESSDATA_LSTM_NUMBER_DAWG, // 20
TESSDATA_NUM_ENTRIES TESSDATA_NUM_ENTRIES
}; };
...@@ -76,24 +84,28 @@ enum TessdataType { ...@@ -76,24 +84,28 @@ enum TessdataType {
* kTessdataFileSuffixes[i] indicates the file suffix for * kTessdataFileSuffixes[i] indicates the file suffix for
* tessdata of type i (from TessdataType enum). * tessdata of type i (from TessdataType enum).
*/ */
static const char * const kTessdataFileSuffixes[] = { static const char *const kTessdataFileSuffixes[] = {
kLangConfigFileSuffix, // 0 kLangConfigFileSuffix, // 0
kUnicharsetFileSuffix, // 1 kUnicharsetFileSuffix, // 1
kAmbigsFileSuffix, // 2 kAmbigsFileSuffix, // 2
kBuiltInTemplatesFileSuffix, // 3 kBuiltInTemplatesFileSuffix, // 3
kBuiltInCutoffsFileSuffix, // 4 kBuiltInCutoffsFileSuffix, // 4
kNormProtoFileSuffix, // 5 kNormProtoFileSuffix, // 5
kPuncDawgFileSuffix, // 6 kPuncDawgFileSuffix, // 6
kSystemDawgFileSuffix, // 7 kSystemDawgFileSuffix, // 7
kNumberDawgFileSuffix, // 8 kNumberDawgFileSuffix, // 8
kFreqDawgFileSuffix, // 9 kFreqDawgFileSuffix, // 9
kFixedLengthDawgsFileSuffix, // 10 // deprecated kFixedLengthDawgsFileSuffix, // 10 // deprecated
kCubeUnicharsetFileSuffix, // 11 kCubeUnicharsetFileSuffix, // 11
kCubeSystemDawgFileSuffix, // 12 kCubeSystemDawgFileSuffix, // 12
kShapeTableFileSuffix, // 13 kShapeTableFileSuffix, // 13
kBigramDawgFileSuffix, // 14 kBigramDawgFileSuffix, // 14
kUnambigDawgFileSuffix, // 15 kUnambigDawgFileSuffix, // 15
kParamsModelFileSuffix, // 16 kParamsModelFileSuffix, // 16
kLSTMModelFileSuffix, // 17
kLSTMPuncDawgFileSuffix, // 18
kLSTMSystemDawgFileSuffix, // 19
kLSTMNumberDawgFileSuffix, // 20
}; };
/** /**
...@@ -101,23 +113,27 @@ static const char * const kTessdataFileSuffixes[] = { ...@@ -101,23 +113,27 @@ static const char * const kTessdataFileSuffixes[] = {
* of type i (from TessdataType enum) is text, and is binary otherwise. * of type i (from TessdataType enum) is text, and is binary otherwise.
*/ */
static const bool kTessdataFileIsText[] = { static const bool kTessdataFileIsText[] = {
true, // 0 true, // 0
true, // 1 true, // 1
true, // 2 true, // 2
false, // 3 false, // 3
true, // 4 true, // 4
true, // 5 true, // 5
false, // 6 false, // 6
false, // 7 false, // 7
false, // 8 false, // 8
false, // 9 false, // 9
false, // 10 // deprecated false, // 10 // deprecated
true, // 11 true, // 11
false, // 12 false, // 12
false, // 13 false, // 13
false, // 14 false, // 14
false, // 15 false, // 15
true, // 16 true, // 16
false, // 17
false, // 18
false, // 19
false, // 20
}; };
/** /**
......
///////////////////////////////////////////////////////////////////////
// File: unicharcompress.cpp
// Description: Unicode re-encoding using a sequence of smaller numbers in
// place of a single large code for CJK, similarly for Indic,
// and dissection of ligatures for other scripts.
// Author: Ray Smith
// Created: Wed Mar 04 14:45:01 PST 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
///////////////////////////////////////////////////////////////////////
#include "unicharcompress.h"
#include "tprintf.h"
namespace tesseract {
// String used to represent the null_id in direct_set.
// static: this constant is private to this translation unit; giving it
// internal linkage avoids a potential ODR clash with any other file that
// defines a kNullChar.
static const char* kNullChar = "<nul>";
// Local helper struct used only while building the Han encoding from the
// radical-stroke table.
struct RadicalStroke {
  RadicalStroke() : num_strokes(0) {}
  RadicalStroke(const STRING& r, int s) : radical(r), num_strokes(s) {}
  bool operator==(const RadicalStroke& other) const {
    return num_strokes == other.num_strokes && radical == other.radical;
  }
  // The radical is kept as a string rather than an int because the table
  // format is an integer with an optional trailing ' mark for a simplified
  // shape. Storing the text form keeps the two variants distinct, and a
  // UNICHARSET is used later to map the strings to integers.
  STRING radical;
  // Stroke counts are dense, so the face value from the table is used as-is.
  int num_strokes;
};
// Hash functor so RadicalStroke can be used as a hash-map key.
struct RadicalStrokedHash {
  size_t operator()(const RadicalStroke& rs) const {
    // Seed with the stroke count, then fold in each radical character at a
    // distinct bit position.
    size_t hash_val = rs.num_strokes;
    for (int idx = 0; idx < rs.radical.length(); ++idx) {
      hash_val ^= rs.radical[idx] << (6 * idx + 8);
    }
    return hash_val;
  }
};
// A hash map to convert unicodes to radical,stroke pair.
typedef TessHashMap<int, RadicalStroke> RSMap;
// A hash map to count occurrences of each radical,stroke pair.
typedef TessHashMap<RadicalStroke, int, RadicalStrokedHash> RSCounts;
// Helper that parses the radical-stroke table (already loaded into a STRING)
// into *radical_map. Returns false on any malformed line.
// The radical_stroke_table argument is non-const because it gets split into
// lines in place; the caller is unlikely to want to reuse it.
static bool DecodeRadicalStrokeTable(STRING* radical_stroke_table,
                                     RSMap* radical_map) {
  GenericVector<STRING> rows;
  radical_stroke_table->split('\n', &rows);
  for (int row = 0; row < rows.size(); ++row) {
    // Skip blank lines and '#' comments.
    if (rows[row].length() == 0 || rows[row][0] == '#') continue;
    int unicode, radical, strokes;
    STRING radical_str;
    // Two accepted formats: "<hex>\t<radical>.<strokes>" and the simplified
    // variant "<hex>\t<radical>'.<strokes>".
    if (sscanf(rows[row].string(), "%x\t%d.%d", &unicode, &radical,
               &strokes) == 3) {
      radical_str.add_str_int("", radical);
    } else if (sscanf(rows[row].string(), "%x\t%d'.%d", &unicode, &radical,
                      &strokes) == 3) {
      radical_str.add_str_int("'", radical);
    } else {
      tprintf("Invalid format in radical stroke table at line %d: %s\n", row,
              rows[row].string());
      return false;
    }
    (*radical_map)[unicode] = RadicalStroke(radical_str, strokes);
  }
  return true;
}
// Default-constructs an empty encoder with no usable code values.
UnicharCompress::UnicharCompress() : code_range_(0) {}
// Copy construction delegates to operator= so the derived decoding tables
// are rebuilt rather than shallow-copied.
UnicharCompress::UnicharCompress(const UnicharCompress& src) { *this = src; }
// Releases the heap-allocated next/final code lists via Cleanup().
UnicharCompress::~UnicharCompress() { Cleanup(); }
// Assignment copies the encoding table and code range, then rebuilds the
// derived decoding structures (decoder_, next_codes_, final_codes_) from
// scratch so no heap pointers are shared with src. Safe under
// self-assignment because Cleanup() does not touch encoder_ or code_range_.
UnicharCompress& UnicharCompress::operator=(const UnicharCompress& src) {
  Cleanup();
  encoder_ = src.encoder_;
  code_range_ = src.code_range_;
  SetupDecoder();
  return *this;
}
// Computes the encoding for the given unicharset. It is a requirement that
// the file training/langdata/radical-stroke.txt have been read into the
// input string radical_stroke_table (which is consumed/split here).
// null_id is the unichar-id of the null character, or -1 if there is none;
// it may also equal unicharset.size() if the unicharset has no room for it.
// Returns false if the encoding cannot be constructed.
bool UnicharCompress::ComputeEncoding(const UNICHARSET& unicharset, int null_id,
                                      STRING* radical_stroke_table) {
  RSMap radical_map;
  if (!DecodeRadicalStrokeTable(radical_stroke_table, &radical_map))
    return false;
  encoder_.clear();
  // direct_set assigns code values for non-Han/Hangul unicodes;
  // radicals assigns integer ids to the (string-form) Han radicals.
  UNICHARSET direct_set;
  UNICHARSET radicals;
  // To avoid unused codes, clear the special codes from the unicharsets.
  direct_set.clear();
  radicals.clear();
  // Always keep space as code 0.
  direct_set.unichar_insert(" ");
  // Null char is next if we have one.
  if (null_id >= 0) {
    direct_set.unichar_insert(kNullChar);
  }
  // Counts occurrences of each radical/stroke pair so duplicates get a
  // distinguishing third (index) code.
  RSCounts radical_counts;
  // In the initial map, codes [0, unicharset.size()) are
  // reserved for non-han/hangul sequences of 1 or more unicodes.
  int hangul_offset = unicharset.size();
  // Hangul takes the next range [hangul_offset, hangul_offset + kTotalJamos).
  const int kTotalJamos = kLCount + kVCount + kTCount;
  // Han takes the codes beyond hangul_offset + kTotalJamos. Since it is hard
  // to measure the number of radicals and strokes, initially we use the same
  // code range for all 3 Han code positions, and fix them after.
  int han_offset = hangul_offset + kTotalJamos;
  int max_num_strokes = -1;
  // Note <= : u == unicharset.size() is allowed for an out-of-set null_id.
  for (int u = 0; u <= unicharset.size(); ++u) {
    bool self_normalized = false;
    // We special-case allow null_id to be equal to unicharset.size() in case
    // there is no space in unicharset for it.
    if (u == unicharset.size()) {
      if (u == null_id) {
        self_normalized = true;
      } else {
        break;  // Finished.
      }
    } else {
      // Self-normalized means the unichar is its own normalized form.
      self_normalized = strcmp(unicharset.id_to_unichar(u),
                               unicharset.get_normed_unichar(u)) == 0;
    }
    RecodedCharID code;
    // Convert to unicodes.
    GenericVector<int> unicodes;
    if (u < unicharset.size() &&
        UNICHAR::UTF8ToUnicode(unicharset.get_normed_unichar(u), &unicodes) &&
        unicodes.size() == 1) {
      // Check single unicodes for Hangul/Han and encode if so.
      int unicode = unicodes[0];
      int leading, vowel, trailing;
      auto it = radical_map.find(unicode);
      if (it != radical_map.end()) {
        // This is Han. Convert to radical, stroke, index.
        if (!radicals.contains_unichar(it->second.radical.string())) {
          radicals.unichar_insert(it->second.radical.string());
        }
        int radical = radicals.unichar_to_id(it->second.radical.string());
        int num_strokes = it->second.num_strokes;
        // num_samples disambiguates unicodes sharing radical+strokes.
        int num_samples = radical_counts[it->second]++;
        if (num_strokes > max_num_strokes) max_num_strokes = num_strokes;
        code.Set3(radical + han_offset, num_strokes + han_offset,
                  num_samples + han_offset);
      } else if (DecomposeHangul(unicode, &leading, &vowel, &trailing)) {
        // This is Hangul. Since we know the exact size of each part at compile
        // time, it gets the bottom set of codes.
        code.Set3(leading + hangul_offset, vowel + kLCount + hangul_offset,
                  trailing + kLCount + kVCount + hangul_offset);
      }
    }
    // If the code is still empty, it wasn't Han or Hangul.
    if (code.length() == 0) {
      // Special cases.
      if (u == UNICHAR_SPACE) {
        code.Set(0, 0);  // Space.
      } else if (u == null_id || (unicharset.has_special_codes() &&
                                  u < SPECIAL_UNICHAR_CODES_COUNT)) {
        code.Set(0, direct_set.unichar_to_id(kNullChar));
      } else {
        // Add the direct_set unichar-ids of the unicodes in sequence to the
        // code.
        for (int i = 0; i < unicodes.size(); ++i) {
          int position = code.length();
          if (position >= RecodedCharID::kMaxCodeLen) {
            tprintf("Unichar %d=%s->%s is too long to encode!!\n", u,
                    unicharset.id_to_unichar(u),
                    unicharset.get_normed_unichar(u));
            return false;
          }
          int uni = unicodes[i];
          UNICHAR unichar(uni);
          char* utf8 = unichar.utf8_str();
          if (!direct_set.contains_unichar(utf8))
            direct_set.unichar_insert(utf8);
          code.Set(position, direct_set.unichar_to_id(utf8));
          delete[] utf8;
          if (direct_set.size() > unicharset.size()) {
            // Code space got bigger than the original unicharset - the
            // "compression" failed, so bail out.
            tprintf("Code space expanded from original unicharset!!\n");
            return false;
          }
        }
      }
    }
    code.set_self_normalized(self_normalized);
    encoder_.push_back(code);
  }
  // Now renumber Han to make all codes unique. We already added han_offset to
  // all Han. Now separate out the radical, stroke, and count codes for Han.
  // In the uniqued Han encoding, the 1st code uses the next radical_map.size()
  // values, the 2nd code uses the next max_num_strokes+1 values, and the 3rd
  // code uses the rest for the max number of duplicated radical/stroke combos.
  int num_radicals = radicals.size();
  for (int u = 0; u < unicharset.size(); ++u) {
    RecodedCharID* code = &encoder_[u];
    if ((*code)(0) >= han_offset) {
      code->Set(1, (*code)(1) + num_radicals);
      code->Set(2, (*code)(2) + num_radicals + max_num_strokes + 1);
    }
  }
  // If there is a null char, it was inserted into direct_set at id 1
  // (after space), hence the encoded_null of 1.
  DefragmentCodeValues(null_id >= 0 ? 1 : -1);
  SetupDecoder();
  return true;
}
// Sets up an encoder that doesn't change the unichars at all, so it just
// passes them through unchanged.
void UnicharCompress::SetupPassThrough(const UNICHARSET& unicharset) {
GenericVector<RecodedCharID> codes;
for (int u = 0; u < unicharset.size(); ++u) {
RecodedCharID code;
code.Set(0, u);
codes.push_back(code);
}
SetupDirect(codes);
}
// Sets up an encoder directly using the given encoding vector, which maps
// unichar_ids to the given codes. Recomputes code_range_ and rebuilds the
// decoding tables from the new encoding.
void UnicharCompress::SetupDirect(const GenericVector<RecodedCharID>& codes) {
  encoder_ = codes;
  ComputeCodeRange();
  SetupDecoder();
}
// Renumbers codes to eliminate unused values, so the code space is compact.
// encoded_null is the code value of the null character (which is moved to
// the end of the code space), or negative if there is none.
void UnicharCompress::DefragmentCodeValues(int encoded_null) {
  // There may not be any Hangul, but even if there is, it is possible that not
  // all codes are used. Likewise with the Han encoding, it is possible that not
  // all numbers of strokes are used.
  ComputeCodeRange();
  // offsets is used first as a used/unused flag per value, then overwritten
  // in place with the (negative) shift to apply to each used value.
  GenericVector<int> offsets;
  offsets.init_to_size(code_range_, 0);
  // Find which codes are used
  for (int c = 0; c < encoder_.size(); ++c) {
    const RecodedCharID& code = encoder_[c];
    for (int i = 0; i < code.length(); ++i) {
      offsets[code(i)] = 1;
    }
  }
  // Compute offsets based on code use.
  int offset = 0;
  for (int i = 0; i < offsets.size(); ++i) {
    // If not used, decrement everything above here.
    // We are moving encoded_null to the end, so it is not "used".
    if (offsets[i] == 0 || i == encoded_null) {
      --offset;
    } else {
      offsets[i] = offset;
    }
  }
  if (encoded_null >= 0) {
    // The encoded_null is moving to the end, for the benefit of TensorFlow,
    // which is offsets.size() + offsets.back().
    offsets[encoded_null] = offsets.size() + offsets.back() - encoded_null;
  }
  // Now apply the offsets.
  for (int c = 0; c < encoder_.size(); ++c) {
    RecodedCharID* code = &encoder_[c];
    for (int i = 0; i < code->length(); ++i) {
      int value = (*code)(i);
      code->Set(i, value + offsets[value]);
    }
  }
  // Re-derive code_range_ from the renumbered codes.
  ComputeCodeRange();
}
// Encodes a single unichar_id into *code. Returns the length of the code,
// or zero for an out-of-range unichar_id.
int UnicharCompress::EncodeUnichar(int unichar_id, RecodedCharID* code) const {
  const bool in_range = 0 <= unichar_id && unichar_id < encoder_.size();
  if (!in_range) return 0;
  *code = encoder_[unichar_id];
  return code->length();
}
// Decodes code, returning the original unichar-id, or INVALID_UNICHAR_ID if
// the code has an invalid length or is not present in the decoder table.
int UnicharCompress::DecodeUnichar(const RecodedCharID& code) const {
  const int len = code.length();
  if (len <= 0 || len > RecodedCharID::kMaxCodeLen) return INVALID_UNICHAR_ID;
  auto entry = decoder_.find(code);
  return entry == decoder_.end() ? INVALID_UNICHAR_ID : entry->second;
}
// Writes to the given file. Returns false in case of error.
// Only the encoder table is written; the decoding structures are rebuilt on
// DeSerialize.
bool UnicharCompress::Serialize(TFile* fp) const {
  return encoder_.SerializeClasses(fp);
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
// Rebuilds code_range_ and the decoding tables from the loaded encoder.
bool UnicharCompress::DeSerialize(bool swap, TFile* fp) {
  if (!encoder_.DeSerializeClasses(swap, fp)) return false;
  ComputeCodeRange();
  SetupDecoder();
  return true;
}
// Returns a STRING containing a text file that describes the encoding thus:
// <index>[,<index>]*<tab><UTF8-str><newline>
// In words, a comma-separated list of one or more indices, followed by a tab
// and the UTF-8 string that the code represents per line. Most simple scripts
// will encode a single index to a UTF8-string, but Chinese, Japanese, Korean
// and the Indic scripts will contain a many-to-many mapping.
// See the class comment above for details.
STRING UnicharCompress::GetEncodingAsString(
    const UNICHARSET& unicharset) const {
  STRING encoding;
  for (int c = 0; c < encoder_.size(); ++c) {
    const RecodedCharID& code = encoder_[c];
    if (0 < c && c < SPECIAL_UNICHAR_CODES_COUNT && code == encoder_[c - 1]) {
      // Don't show the duplicate entry.
      continue;
    }
    // Emit the comma-separated code values.
    encoding.add_str_int("", code(0));
    for (int i = 1; i < code.length(); ++i) {
      encoding.add_str_int(",", code(i));
    }
    encoding += "\t";
    // Out-of-set ids and the special codes are shown as the null char.
    if (c >= unicharset.size() || (0 < c && c < SPECIAL_UNICHAR_CODES_COUNT &&
                                   unicharset.has_special_codes())) {
      encoding += kNullChar;
    } else {
      encoding += unicharset.id_to_unichar(c);
    }
    encoding += "\n";
  }
  return encoding;
}
// Helper decomposes a Hangul unicode to 3 parts, leading, vowel, trailing.
// Note that the returned values are 0-based indices, NOT unicode Jamo.
// Returns false if the input is not in the Hangul unicode range.
/* static */
bool UnicharCompress::DecomposeHangul(int unicode, int* leading, int* vowel,
                                      int* trailing) {
  const int offset = unicode - kFirstHangul;
  if (offset < 0 || offset >= kNumHangul) return false;
  // Number of syllables sharing one leading consonant.
  const int kNCount = kVCount * kTCount;
  *leading = offset / kNCount;
  *vowel = (offset % kNCount) / kTCount;
  *trailing = offset % kTCount;
  return true;
}
// Recomputes code_range_ as 1 + the maximum code value present anywhere in
// encoder_ (0 for an empty encoder).
void UnicharCompress::ComputeCodeRange() {
  int max_code = -1;
  for (int u = 0; u < encoder_.size(); ++u) {
    const RecodedCharID& code = encoder_[u];
    for (int pos = 0; pos < code.length(); ++pos) {
      if (code(pos) > max_code) max_code = code(pos);
    }
  }
  code_range_ = max_code + 1;
}
// Initializes the decoding hash_map (and the prefix tables next_codes_ and
// final_codes_, plus is_valid_start_) from the encoding array.
void UnicharCompress::SetupDecoder() {
  Cleanup();
  is_valid_start_.init_to_size(code_range_, false);
  for (int c = 0; c < encoder_.size(); ++c) {
    const RecodedCharID& code = encoder_[c];
    // The self-normalized entry wins when several unichar-ids share a code;
    // otherwise first-come keeps its slot.
    if (code.self_normalized() || decoder_.find(code) == decoder_.end())
      decoder_[code] = c;
    is_valid_start_[code(0)] = true;
    // Register code's last element as a final code of its prefix, and walk
    // shorter prefixes registering the following element as a "next" code.
    RecodedCharID prefix = code;
    int len = code.length() - 1;
    prefix.Truncate(len);
    auto final_it = final_codes_.find(prefix);
    if (final_it == final_codes_.end()) {
      GenericVectorEqEq<int>* code_list = new GenericVectorEqEq<int>;
      code_list->push_back(code(len));
      final_codes_[prefix] = code_list;
      while (--len >= 0) {
        prefix.Truncate(len);
        auto next_it = next_codes_.find(prefix);
        if (next_it == next_codes_.end()) {
          GenericVectorEqEq<int>* code_list = new GenericVectorEqEq<int>;
          code_list->push_back(code(len));
          next_codes_[prefix] = code_list;
        } else {
          // We still have to search the list as we may get here via multiple
          // lengths of code.
          if (!next_it->second->contains(code(len)))
            next_it->second->push_back(code(len));
          break;  // This prefix has been processed.
        }
      }
    } else {
      // Prefix already known: just add this final code if it is new.
      if (!final_it->second->contains(code(len)))
        final_it->second->push_back(code(len));
    }
  }
}
// Frees the owned code lists and empties all decoding structures.
void UnicharCompress::Cleanup() {
  decoder_.clear();
  is_valid_start_.clear();
  // The maps own their GenericVectorEqEq values; delete before clearing.
  for (auto it = next_codes_.begin(); it != next_codes_.end(); ++it)
    delete it->second;
  next_codes_.clear();
  for (auto it = final_codes_.begin(); it != final_codes_.end(); ++it)
    delete it->second;
  final_codes_.clear();
}
} // namespace tesseract.
///////////////////////////////////////////////////////////////////////
// File: unicharcompress.h
// Description: Unicode re-encoding using a sequence of smaller numbers in
// place of a single large code for CJK, similarly for Indic,
// and dissection of ligatures for other scripts.
// Author: Ray Smith
// Created: Wed Mar 04 14:45:01 PST 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_CCUTIL_UNICHARCOMPRESS_H_
#define TESSERACT_CCUTIL_UNICHARCOMPRESS_H_
#include "hashfn.h"
#include "serialis.h"
#include "strngs.h"
#include "unicharset.h"
namespace tesseract {
// Trivial class to hold the code for a recoded unichar-id.
class RecodedCharID {
 public:
  // The maximum length of a code.
  static const int kMaxCodeLen = 9;

  RecodedCharID() : self_normalized_(0), length_(0) {
    memset(code_, 0, sizeof(code_));
  }
  // Resets the length to the given value; does not clear dropped elements.
  void Truncate(int length) { length_ = length; }
  // Sets the code value at the given index in the code, extending length_
  // if needed. index must be < kMaxCodeLen.
  void Set(int index, int value) {
    code_[index] = value;
    if (length_ <= index) length_ = index + 1;
  }
  // Shorthand for setting codes of length 3, as all Hangul and Han codes are
  // length 3.
  void Set3(int code0, int code1, int code2) {
    length_ = 3;
    code_[0] = code0;
    code_[1] = code1;
    code_[2] = code2;
  }
  // Accessors
  bool self_normalized() const { return self_normalized_ != 0; }
  void set_self_normalized(bool value) { self_normalized_ = value; }
  int length() const { return length_; }
  int operator()(int index) const { return code_[index]; }
  // Writes to the given file. Returns false in case of error.
  bool Serialize(TFile* fp) const {
    if (fp->FWrite(&self_normalized_, sizeof(self_normalized_), 1) != 1)
      return false;
    if (fp->FWrite(&length_, sizeof(length_), 1) != 1) return false;
    if (fp->FWrite(code_, sizeof(code_[0]), length_) != length_) return false;
    return true;
  }
  // Reads from the given file. Returns false in case of error.
  // If swap is true, assumes a big/little-endian swap is needed.
  bool DeSerialize(bool swap, TFile* fp) {
    if (fp->FRead(&self_normalized_, sizeof(self_normalized_), 1) != 1)
      return false;
    if (fp->FRead(&length_, sizeof(length_), 1) != 1) return false;
    if (swap) ReverseN(&length_, sizeof(length_));
    // Reject corrupt input: an out-of-range length would overflow the
    // fixed-size code_ array in the read below.
    if (length_ < 0 || length_ > kMaxCodeLen) return false;
    if (fp->FRead(code_, sizeof(code_[0]), length_) != length_) return false;
    if (swap) {
      for (int i = 0; i < length_; ++i) {
        ReverseN(&code_[i], sizeof(code_[i]));
      }
    }
    return true;
  }
  // Equality compares length and the in-use code elements only.
  bool operator==(const RecodedCharID& other) const {
    if (length_ != other.length_) return false;
    for (int i = 0; i < length_; ++i) {
      if (code_[i] != other.code_[i]) return false;
    }
    return true;
  }
  // Hash functor for RecodedCharID.
  struct RecodedCharIDHash {
    size_t operator()(const RecodedCharID& code) const {
      size_t result = 0;
      for (int i = 0; i < code.length_; ++i) {
        result ^= code(i) << (7 * i);
      }
      return result;
    }
  };

 private:
  // True if this code is self-normalizing, ie is the master entry for indices
  // that map to the same code. Has boolean value, but inT8 for serialization.
  inT8 self_normalized_;
  // The number of elements in use in code_;
  inT32 length_;
  // The re-encoded form of the unichar-id to which this RecodedCharID relates.
  inT32 code_[kMaxCodeLen];
};
// Class holds a "compression" of a unicharset to simplify the learning problem
// for a neural-network-based classifier.
// Objectives:
// 1 (CJK): Ids of a unicharset with a large number of classes are expressed as
// a sequence of 3 codes with much fewer values.
// This is achieved using the Jamo coding for Hangul and the Unicode
// Radical-Stroke-index for Han.
// 2 (Indic): Instead of thousands of codes with one for each grapheme, re-code
// as the unicode sequence (but coded in a more compact space).
// 3 (the rest): Eliminate multi-path problems with ligatures and fold confusing
// and not significantly distinct shapes (quotes) together, ie
// represent the fi ligature as the f-i pair, and fold u+2019 and
// friends all onto ascii single '
// 4 The null character and mapping to target activations:
// To save horizontal coding space, the compressed codes are generally mapped
// to target network activations without intervening null characters, BUT
// in the case of ligatures, such as ff, null characters have to be included
// so existence of repeated codes is detected at codebook-building time, and
// null characters are embedded directly into the codes, so the rest of the
// system doesn't need to worry about the problem (much). There is still an
// effect on the range of ways in which the target activations can be
// generated.
//
// The computed code values are compact (no unused values), and, for CJK,
// unique (each code position uses a disjoint set of values from each other code
// position). For non-CJK, the same code value CAN be used in multiple
// positions, eg the ff ligature is converted to <f> <nullchar> <f>, where <f>
// is the same code as is used for the single f.
// NOTE that an intended consequence of using the normalized text from the
// unicharset is that the fancy quotes all map to a single code, so round-trip
// conversion doesn't work for all unichar-ids.
class UnicharCompress {
 public:
  UnicharCompress();
  // Copying and assignment are supported.
  // NOTE(review): the maps next_codes_/final_codes_ own their vectors, so
  // these presumably deep-copy — confirm against the implementation.
  UnicharCompress(const UnicharCompress& src);
  ~UnicharCompress();
  UnicharCompress& operator=(const UnicharCompress& src);

  // The 1st Hangul unicode.
  static const int kFirstHangul = 0xac00;
  // The number of Hangul unicodes.
  static const int kNumHangul = 11172;
  // The number of Jamos for each of the 3 parts of a Hangul character, being
  // the Leading consonant, Vowel and Trailing consonant.
  static const int kLCount = 19;
  static const int kVCount = 21;
  static const int kTCount = 28;

  // Computes the encoding for the given unicharset. It is a requirement that
  // the file training/langdata/radical-stroke.txt have been read into the
  // input string radical_stroke_table.
  // Returns false if the encoding cannot be constructed.
  bool ComputeEncoding(const UNICHARSET& unicharset, int null_id,
                       STRING* radical_stroke_table);
  // Sets up an encoder that doesn't change the unichars at all, so it just
  // passes them through unchanged.
  void SetupPassThrough(const UNICHARSET& unicharset);
  // Sets up an encoder directly using the given encoding vector, which maps
  // unichar_ids to the given codes.
  void SetupDirect(const GenericVector<RecodedCharID>& codes);
  // Returns the number of different values that can be used in a code, ie
  // 1 + the maximum value that will ever be used by an RecodedCharID code in
  // any position in its array.
  int code_range() const { return code_range_; }
  // Encodes a single unichar_id. Returns the length of the code, (or zero if
  // invalid input), and the encoding itself in code.
  int EncodeUnichar(int unichar_id, RecodedCharID* code) const;
  // Decodes code, returning the original unichar-id, or
  // INVALID_UNICHAR_ID if the input is invalid. Note that this is not a perfect
  // inverse of EncodeUnichar, since the unichar-id of U+2019 (curly single
  // quote), for example, will have the same encoding as the unichar-id of
  // U+0027 (ascii '). The foldings are obtained from the input unicharset,
  // which in turn obtains them from NormalizeUTF8String in normstrngs.cpp,
  // and include NFKC normalization plus others like quote and dash folding.
  int DecodeUnichar(const RecodedCharID& code) const;
  // Returns true if the given code is a valid start or single code.
  // NOTE(review): no bounds check — code is assumed to be < code_range(),
  // as produced by EncodeUnichar; verify callers respect that.
  bool IsValidFirstCode(int code) const { return is_valid_start_[code]; }
  // Returns a list of valid non-final next codes for a given prefix code,
  // which may be empty. Returns NULL when the prefix has no entry at all.
  const GenericVector<int>* GetNextCodes(const RecodedCharID& code) const {
    auto it = next_codes_.find(code);
    return it == next_codes_.end() ? NULL : it->second;
  }
  // Returns a list of valid final codes for a given prefix code, which may
  // be empty. Returns NULL when the prefix has no entry at all.
  const GenericVector<int>* GetFinalCodes(const RecodedCharID& code) const {
    auto it = final_codes_.find(code);
    return it == final_codes_.end() ? NULL : it->second;
  }
  // Writes to the given file. Returns false in case of error.
  bool Serialize(TFile* fp) const;
  // Reads from the given file. Returns false in case of error.
  // If swap is true, assumes a big/little-endian swap is needed.
  bool DeSerialize(bool swap, TFile* fp);
  // Returns a STRING containing a text file that describes the encoding thus:
  // <index>[,<index>]*<tab><UTF8-str><newline>
  // In words, a comma-separated list of one or more indices, followed by a tab
  // and the UTF-8 string that the code represents per line. Most simple scripts
  // will encode a single index to a UTF8-string, but Chinese, Japanese, Korean
  // and the Indic scripts will contain a many-to-many mapping.
  // See the class comment above for details.
  STRING GetEncodingAsString(const UNICHARSET& unicharset) const;
  // Helper decomposes a Hangul unicode to 3 parts, leading, vowel, trailing.
  // Note that the returned values are 0-based indices, NOT unicode Jamo.
  // Returns false if the input is not in the Hangul unicode range.
  static bool DecomposeHangul(int unicode, int* leading, int* vowel,
                              int* trailing);

 private:
  // Renumbers codes to eliminate unused values.
  void DefragmentCodeValues(int encoded_null);
  // Computes the value of code_range_ from the encoder_.
  void ComputeCodeRange();
  // Initializes the decoding hash_map from the encoder_ array.
  void SetupDecoder();
  // Frees allocated memory.
  void Cleanup();

  // The encoder that maps a unichar-id to a sequence of small codes.
  // encoder_ is the only part that is serialized. The rest is computed on load.
  GenericVector<RecodedCharID> encoder_;
  // Decoder converts the output of encoder back to a unichar-id.
  TessHashMap<RecodedCharID, int, RecodedCharID::RecodedCharIDHash> decoder_;
  // True if the index is a valid single or start code.
  GenericVector<bool> is_valid_start_;
  // Maps a prefix code to a list of valid next codes.
  // The map owns the vectors.
  TessHashMap<RecodedCharID, GenericVectorEqEq<int>*,
              RecodedCharID::RecodedCharIDHash>
      next_codes_;
  // Maps a prefix code to a list of valid final codes.
  // The map owns the vectors.
  TessHashMap<RecodedCharID, GenericVectorEqEq<int>*,
              RecodedCharID::RecodedCharIDHash>
      final_codes_;
  // Max of any value in encoder_ + 1.
  int code_range_;
};
} // namespace tesseract.
#endif // TESSERACT_CCUTIL_UNICHARCOMPRESS_H_
...@@ -906,6 +906,8 @@ void UNICHARSET::post_load_setup() { ...@@ -906,6 +906,8 @@ void UNICHARSET::post_load_setup() {
han_sid_ = get_script_id_from_name("Han"); han_sid_ = get_script_id_from_name("Han");
hiragana_sid_ = get_script_id_from_name("Hiragana"); hiragana_sid_ = get_script_id_from_name("Hiragana");
katakana_sid_ = get_script_id_from_name("Katakana"); katakana_sid_ = get_script_id_from_name("Katakana");
thai_sid_ = get_script_id_from_name("Thai");
hangul_sid_ = get_script_id_from_name("Hangul");
// Compute default script. Use the highest-counting alpha script, that is // Compute default script. Use the highest-counting alpha script, that is
// not the common script, as that still contains some "alphas". // not the common script, as that still contains some "alphas".
......
...@@ -290,6 +290,8 @@ class UNICHARSET { ...@@ -290,6 +290,8 @@ class UNICHARSET {
han_sid_ = 0; han_sid_ = 0;
hiragana_sid_ = 0; hiragana_sid_ = 0;
katakana_sid_ = 0; katakana_sid_ = 0;
thai_sid_ = 0;
hangul_sid_ = 0;
} }
// Return the size of the set (the number of different UNICHAR it holds). // Return the size of the set (the number of different UNICHAR it holds).
...@@ -604,6 +606,16 @@ class UNICHARSET { ...@@ -604,6 +606,16 @@ class UNICHARSET {
return unichars[unichar_id].properties.AnyRangeEmpty(); return unichars[unichar_id].properties.AnyRangeEmpty();
} }
// Returns true if the script of the given id is space delimited.
// Returns false for Han and Thai scripts.
bool IsSpaceDelimited(UNICHAR_ID unichar_id) const {
if (INVALID_UNICHAR_ID == unichar_id) return true;
int script_id = get_script(unichar_id);
return script_id != han_sid_ && script_id != thai_sid_ &&
script_id != hangul_sid_ && script_id != hiragana_sid_ &&
script_id != katakana_sid_;
}
// Return the script name of the given unichar. // Return the script name of the given unichar.
// The returned pointer will always be the same for the same script, it's // The returned pointer will always be the same for the same script, it's
// managed by unicharset and thus MUST NOT be deleted // managed by unicharset and thus MUST NOT be deleted
...@@ -773,7 +785,7 @@ class UNICHARSET { ...@@ -773,7 +785,7 @@ class UNICHARSET {
// Returns normalized version of unichar with the given unichar_id. // Returns normalized version of unichar with the given unichar_id.
const char *get_normed_unichar(UNICHAR_ID unichar_id) const { const char *get_normed_unichar(UNICHAR_ID unichar_id) const {
if (unichar_id == UNICHAR_SPACE && has_special_codes()) return " "; if (unichar_id == UNICHAR_SPACE) return " ";
return unichars[unichar_id].properties.normed.string(); return unichars[unichar_id].properties.normed.string();
} }
// Returns a vector of UNICHAR_IDs that represent the ids of the normalized // Returns a vector of UNICHAR_IDs that represent the ids of the normalized
...@@ -835,6 +847,8 @@ class UNICHARSET { ...@@ -835,6 +847,8 @@ class UNICHARSET {
int han_sid() const { return han_sid_; } int han_sid() const { return han_sid_; }
int hiragana_sid() const { return hiragana_sid_; } int hiragana_sid() const { return hiragana_sid_; }
int katakana_sid() const { return katakana_sid_; } int katakana_sid() const { return katakana_sid_; }
int thai_sid() const { return thai_sid_; }
int hangul_sid() const { return hangul_sid_; }
int default_sid() const { return default_sid_; } int default_sid() const { return default_sid_; }
// Returns true if the unicharset has the concept of upper/lower case. // Returns true if the unicharset has the concept of upper/lower case.
...@@ -977,6 +991,8 @@ class UNICHARSET { ...@@ -977,6 +991,8 @@ class UNICHARSET {
int han_sid_; int han_sid_;
int hiragana_sid_; int hiragana_sid_;
int katakana_sid_; int katakana_sid_;
int thai_sid_;
int hangul_sid_;
// The most frequently occurring script in the charset. // The most frequently occurring script in the charset.
int default_sid_; int default_sid_;
}; };
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
# Initialization # Initialization
# ---------------------------------------- # ----------------------------------------
AC_PREREQ([2.50]) AC_PREREQ([2.50])
AC_INIT([tesseract], [3.05.00dev], [https://github.com/tesseract-ocr/tesseract/issues]) AC_INIT([tesseract], [4.00.00dev], [https://github.com/tesseract-ocr/tesseract/issues])
AC_PROG_CXX([g++ clang++]) AC_PROG_CXX([g++ clang++])
AC_LANG([C++]) AC_LANG([C++])
AC_LANG_COMPILER_REQUIRE AC_LANG_COMPILER_REQUIRE
...@@ -18,8 +18,8 @@ AC_PREFIX_DEFAULT([/usr/local]) ...@@ -18,8 +18,8 @@ AC_PREFIX_DEFAULT([/usr/local])
# Define date of package, etc. Could be useful in auto-generated # Define date of package, etc. Could be useful in auto-generated
# documentation. # documentation.
PACKAGE_YEAR=2015 PACKAGE_YEAR=2016
PACKAGE_DATE="07/11" PACKAGE_DATE="11/11"
abs_top_srcdir=`AS_DIRNAME([$0])` abs_top_srcdir=`AS_DIRNAME([$0])`
gitrev="`git --git-dir=${abs_top_srcdir}/.git --work-tree=${abs_top_srcdir} describe --always --tags`" gitrev="`git --git-dir=${abs_top_srcdir}/.git --work-tree=${abs_top_srcdir} describe --always --tags`"
...@@ -42,8 +42,8 @@ AC_SUBST([PACKAGE_DATE]) ...@@ -42,8 +42,8 @@ AC_SUBST([PACKAGE_DATE])
GENERIC_LIBRARY_NAME=tesseract GENERIC_LIBRARY_NAME=tesseract
# Release versioning # Release versioning
GENERIC_MAJOR_VERSION=3 GENERIC_MAJOR_VERSION=4
GENERIC_MINOR_VERSION=4 GENERIC_MINOR_VERSION=0
GENERIC_MICRO_VERSION=0 GENERIC_MICRO_VERSION=0
# API version (often = GENERIC_MAJOR_VERSION.GENERIC_MINOR_VERSION) # API version (often = GENERIC_MAJOR_VERSION.GENERIC_MINOR_VERSION)
...@@ -520,6 +520,7 @@ fi ...@@ -520,6 +520,7 @@ fi
# Output files # Output files
AC_CONFIG_FILES([Makefile tesseract.pc]) AC_CONFIG_FILES([Makefile tesseract.pc])
AC_CONFIG_FILES([api/Makefile]) AC_CONFIG_FILES([api/Makefile])
AC_CONFIG_FILES([arch/Makefile])
AC_CONFIG_FILES([ccmain/Makefile]) AC_CONFIG_FILES([ccmain/Makefile])
AC_CONFIG_FILES([opencl/Makefile]) AC_CONFIG_FILES([opencl/Makefile])
AC_CONFIG_FILES([ccstruct/Makefile]) AC_CONFIG_FILES([ccstruct/Makefile])
...@@ -528,6 +529,7 @@ AC_CONFIG_FILES([classify/Makefile]) ...@@ -528,6 +529,7 @@ AC_CONFIG_FILES([classify/Makefile])
AC_CONFIG_FILES([cube/Makefile]) AC_CONFIG_FILES([cube/Makefile])
AC_CONFIG_FILES([cutil/Makefile]) AC_CONFIG_FILES([cutil/Makefile])
AC_CONFIG_FILES([dict/Makefile]) AC_CONFIG_FILES([dict/Makefile])
AC_CONFIG_FILES([lstm/Makefile])
AC_CONFIG_FILES([neural_networks/runtime/Makefile]) AC_CONFIG_FILES([neural_networks/runtime/Makefile])
AC_CONFIG_FILES([textord/Makefile]) AC_CONFIG_FILES([textord/Makefile])
AC_CONFIG_FILES([viewer/Makefile]) AC_CONFIG_FILES([viewer/Makefile])
......
...@@ -401,7 +401,6 @@ LIST s_adjoin(LIST var_list, void *variable, int_compare compare) { ...@@ -401,7 +401,6 @@ LIST s_adjoin(LIST var_list, void *variable, int_compare compare) {
return (push_last (var_list, variable)); return (push_last (var_list, variable));
} }
/********************************************************************** /**********************************************************************
* s e a r c h * s e a r c h
* *
......
...@@ -69,14 +69,17 @@ Dawg *DawgLoader::Load() { ...@@ -69,14 +69,17 @@ Dawg *DawgLoader::Load() {
PermuterType perm_type; PermuterType perm_type;
switch (tessdata_dawg_type_) { switch (tessdata_dawg_type_) {
case TESSDATA_PUNC_DAWG: case TESSDATA_PUNC_DAWG:
case TESSDATA_LSTM_PUNC_DAWG:
dawg_type = DAWG_TYPE_PUNCTUATION; dawg_type = DAWG_TYPE_PUNCTUATION;
perm_type = PUNC_PERM; perm_type = PUNC_PERM;
break; break;
case TESSDATA_SYSTEM_DAWG: case TESSDATA_SYSTEM_DAWG:
case TESSDATA_LSTM_SYSTEM_DAWG:
dawg_type = DAWG_TYPE_WORD; dawg_type = DAWG_TYPE_WORD;
perm_type = SYSTEM_DAWG_PERM; perm_type = SYSTEM_DAWG_PERM;
break; break;
case TESSDATA_NUMBER_DAWG: case TESSDATA_NUMBER_DAWG:
case TESSDATA_LSTM_NUMBER_DAWG:
dawg_type = DAWG_TYPE_NUMBER; dawg_type = DAWG_TYPE_NUMBER;
perm_type = NUMBER_PERM; perm_type = NUMBER_PERM;
break; break;
......
...@@ -202,10 +202,8 @@ DawgCache *Dict::GlobalDawgCache() { ...@@ -202,10 +202,8 @@ DawgCache *Dict::GlobalDawgCache() {
return &cache; return &cache;
} }
void Dict::Load(DawgCache *dawg_cache) { // Sets up ready for a Load or LoadLSTM.
STRING name; void Dict::SetupForLoad(DawgCache *dawg_cache) {
STRING &lang = getCCUtil()->lang;
if (dawgs_.length() != 0) this->End(); if (dawgs_.length() != 0) this->End();
apostrophe_unichar_id_ = getUnicharset().unichar_to_id(kApostropheSymbol); apostrophe_unichar_id_ = getUnicharset().unichar_to_id(kApostropheSymbol);
...@@ -220,10 +218,10 @@ void Dict::Load(DawgCache *dawg_cache) { ...@@ -220,10 +218,10 @@ void Dict::Load(DawgCache *dawg_cache) {
dawg_cache_ = new DawgCache(); dawg_cache_ = new DawgCache();
dawg_cache_is_ours_ = true; dawg_cache_is_ours_ = true;
} }
}
TessdataManager &tessdata_manager = getCCUtil()->tessdata_manager; // Loads the dawgs needed by Tesseract. Call FinishLoad() after.
const char *data_file_name = tessdata_manager.GetDataFileName().string(); void Dict::Load(const char *data_file_name, const STRING &lang) {
// Load dawgs_. // Load dawgs_.
if (load_punc_dawg) { if (load_punc_dawg) {
punc_dawg_ = dawg_cache_->GetSquishedDawg( punc_dawg_ = dawg_cache_->GetSquishedDawg(
...@@ -255,6 +253,7 @@ void Dict::Load(DawgCache *dawg_cache) { ...@@ -255,6 +253,7 @@ void Dict::Load(DawgCache *dawg_cache) {
if (unambig_dawg_) dawgs_ += unambig_dawg_; if (unambig_dawg_) dawgs_ += unambig_dawg_;
} }
STRING name;
if (((STRING &)user_words_suffix).length() > 0 || if (((STRING &)user_words_suffix).length() > 0 ||
((STRING &)user_words_file).length() > 0) { ((STRING &)user_words_file).length() > 0) {
Trie *trie_ptr = new Trie(DAWG_TYPE_WORD, lang, USER_DAWG_PERM, Trie *trie_ptr = new Trie(DAWG_TYPE_WORD, lang, USER_DAWG_PERM,
...@@ -300,8 +299,33 @@ void Dict::Load(DawgCache *dawg_cache) { ...@@ -300,8 +299,33 @@ void Dict::Load(DawgCache *dawg_cache) {
// This dawg is temporary and should not be searched by letter_is_ok. // This dawg is temporary and should not be searched by letter_is_ok.
pending_words_ = new Trie(DAWG_TYPE_WORD, lang, NO_PERM, pending_words_ = new Trie(DAWG_TYPE_WORD, lang, NO_PERM,
getUnicharset().size(), dawg_debug_level); getUnicharset().size(), dawg_debug_level);
}
// Construct a list of corresponding successors for each dawg. Each entry i // Loads the dawgs needed by the LSTM model. Call FinishLoad() after.
void Dict::LoadLSTM(const char *data_file_name, const STRING &lang) {
// Load dawgs_.
if (load_punc_dawg) {
punc_dawg_ = dawg_cache_->GetSquishedDawg(
lang, data_file_name, TESSDATA_LSTM_PUNC_DAWG, dawg_debug_level);
if (punc_dawg_) dawgs_ += punc_dawg_;
}
if (load_system_dawg) {
Dawg *system_dawg = dawg_cache_->GetSquishedDawg(
lang, data_file_name, TESSDATA_LSTM_SYSTEM_DAWG, dawg_debug_level);
if (system_dawg) dawgs_ += system_dawg;
}
if (load_number_dawg) {
Dawg *number_dawg = dawg_cache_->GetSquishedDawg(
lang, data_file_name, TESSDATA_LSTM_NUMBER_DAWG, dawg_debug_level);
if (number_dawg) dawgs_ += number_dawg;
}
}
// Completes the loading process after Load() and/or LoadLSTM().
// Returns false if no dictionaries were loaded.
bool Dict::FinishLoad() {
if (dawgs_.empty()) return false;
// Construct a list of corresponding successors for each dawg. Each entry, i,
// in the successors_ vector is a vector of integers that represent the // in the successors_ vector is a vector of integers that represent the
// indices into the dawgs_ vector of the successors for dawg i. // indices into the dawgs_ vector of the successors for dawg i.
successors_.reserve(dawgs_.length()); successors_.reserve(dawgs_.length());
...@@ -316,6 +340,7 @@ void Dict::Load(DawgCache *dawg_cache) { ...@@ -316,6 +340,7 @@ void Dict::Load(DawgCache *dawg_cache) {
} }
successors_ += lst; successors_ += lst;
} }
return true;
} }
void Dict::End() { void Dict::End() {
...@@ -368,6 +393,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args, ...@@ -368,6 +393,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
// Initialization. // Initialization.
PermuterType curr_perm = NO_PERM; PermuterType curr_perm = NO_PERM;
dawg_args->updated_dawgs->clear(); dawg_args->updated_dawgs->clear();
dawg_args->valid_end = false;
// Go over the active_dawgs vector and insert DawgPosition records // Go over the active_dawgs vector and insert DawgPosition records
// with the updated ref (an edge with the corresponding unichar id) into // with the updated ref (an edge with the corresponding unichar id) into
...@@ -405,6 +431,9 @@ int Dict::def_letter_is_okay(void* void_dawg_args, ...@@ -405,6 +431,9 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
dawg_debug_level > 0, dawg_debug_level > 0,
"Append transition from punc dawg to current dawgs: "); "Append transition from punc dawg to current dawgs: ");
if (sdawg->permuter() > curr_perm) curr_perm = sdawg->permuter(); if (sdawg->permuter() > curr_perm) curr_perm = sdawg->permuter();
if (sdawg->end_of_word(dawg_edge) &&
punc_dawg->end_of_word(punc_transition_edge))
dawg_args->valid_end = true;
} }
} }
} }
...@@ -419,6 +448,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args, ...@@ -419,6 +448,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
dawg_debug_level > 0, dawg_debug_level > 0,
"Extend punctuation dawg: "); "Extend punctuation dawg: ");
if (PUNC_PERM > curr_perm) curr_perm = PUNC_PERM; if (PUNC_PERM > curr_perm) curr_perm = PUNC_PERM;
if (punc_dawg->end_of_word(punc_edge)) dawg_args->valid_end = true;
} }
continue; continue;
} }
...@@ -436,6 +466,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args, ...@@ -436,6 +466,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
dawg_debug_level > 0, dawg_debug_level > 0,
"Return to punctuation dawg: "); "Return to punctuation dawg: ");
if (dawg->permuter() > curr_perm) curr_perm = dawg->permuter(); if (dawg->permuter() > curr_perm) curr_perm = dawg->permuter();
if (punc_dawg->end_of_word(punc_edge)) dawg_args->valid_end = true;
} }
} }
...@@ -445,8 +476,8 @@ int Dict::def_letter_is_okay(void* void_dawg_args, ...@@ -445,8 +476,8 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
// possible edges, not only for the exact unichar_id, but also // possible edges, not only for the exact unichar_id, but also
// for all its character classes (alpha, digit, etc). // for all its character classes (alpha, digit, etc).
if (dawg->type() == DAWG_TYPE_PATTERN) { if (dawg->type() == DAWG_TYPE_PATTERN) {
ProcessPatternEdges(dawg, pos, unichar_id, word_end, ProcessPatternEdges(dawg, pos, unichar_id, word_end, dawg_args,
dawg_args->updated_dawgs, &curr_perm); &curr_perm);
// There can't be any successors to dawg that is of type // There can't be any successors to dawg that is of type
// DAWG_TYPE_PATTERN, so we are done examining this DawgPosition. // DAWG_TYPE_PATTERN, so we are done examining this DawgPosition.
continue; continue;
...@@ -473,6 +504,9 @@ int Dict::def_letter_is_okay(void* void_dawg_args, ...@@ -473,6 +504,9 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
continue; continue;
} }
if (dawg->permuter() > curr_perm) curr_perm = dawg->permuter(); if (dawg->permuter() > curr_perm) curr_perm = dawg->permuter();
if (dawg->end_of_word(edge) &&
(punc_dawg == NULL || punc_dawg->end_of_word(pos.punc_ref)))
dawg_args->valid_end = true;
dawg_args->updated_dawgs->add_unique( dawg_args->updated_dawgs->add_unique(
DawgPosition(pos.dawg_index, edge, pos.punc_index, pos.punc_ref, DawgPosition(pos.dawg_index, edge, pos.punc_index, pos.punc_ref,
false), false),
...@@ -497,7 +531,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args, ...@@ -497,7 +531,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
void Dict::ProcessPatternEdges(const Dawg *dawg, const DawgPosition &pos, void Dict::ProcessPatternEdges(const Dawg *dawg, const DawgPosition &pos,
UNICHAR_ID unichar_id, bool word_end, UNICHAR_ID unichar_id, bool word_end,
DawgPositionVector *updated_dawgs, DawgArgs *dawg_args,
PermuterType *curr_perm) const { PermuterType *curr_perm) const {
NODE_REF node = GetStartingNode(dawg, pos.dawg_ref); NODE_REF node = GetStartingNode(dawg, pos.dawg_ref);
// Try to find the edge corresponding to the exact unichar_id and to all the // Try to find the edge corresponding to the exact unichar_id and to all the
...@@ -520,7 +554,8 @@ void Dict::ProcessPatternEdges(const Dawg *dawg, const DawgPosition &pos, ...@@ -520,7 +554,8 @@ void Dict::ProcessPatternEdges(const Dawg *dawg, const DawgPosition &pos,
tprintf("Letter found in pattern dawg %d\n", pos.dawg_index); tprintf("Letter found in pattern dawg %d\n", pos.dawg_index);
} }
if (dawg->permuter() > *curr_perm) *curr_perm = dawg->permuter(); if (dawg->permuter() > *curr_perm) *curr_perm = dawg->permuter();
updated_dawgs->add_unique( if (dawg->end_of_word(edge)) dawg_args->valid_end = true;
dawg_args->updated_dawgs->add_unique(
DawgPosition(pos.dawg_index, edge, pos.punc_index, pos.punc_ref, DawgPosition(pos.dawg_index, edge, pos.punc_index, pos.punc_ref,
pos.back_to_punc), pos.back_to_punc),
dawg_debug_level > 0, dawg_debug_level > 0,
...@@ -816,5 +851,13 @@ bool Dict::valid_punctuation(const WERD_CHOICE &word) { ...@@ -816,5 +851,13 @@ bool Dict::valid_punctuation(const WERD_CHOICE &word) {
return false; return false;
} }
/// Returns true if the language is space-delimited (not CJ, or T).
bool Dict::IsSpaceDelimitedLang() const {
const UNICHARSET &u_set = getUnicharset();
if (u_set.han_sid() > 0) return false;
if (u_set.katakana_sid() > 0) return false;
if (u_set.thai_sid() > 0) return false;
return true;
}
} // namespace tesseract } // namespace tesseract
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
#include "dawg.h" #include "dawg.h"
#include "dawg_cache.h" #include "dawg_cache.h"
#include "host.h" #include "host.h"
#include "oldlist.h"
#include "ratngs.h" #include "ratngs.h"
#include "stopper.h" #include "stopper.h"
#include "trie.h" #include "trie.h"
...@@ -76,11 +75,13 @@ enum XHeightConsistencyEnum {XH_GOOD, XH_SUBNORMAL, XH_INCONSISTENT}; ...@@ -76,11 +75,13 @@ enum XHeightConsistencyEnum {XH_GOOD, XH_SUBNORMAL, XH_INCONSISTENT};
struct DawgArgs { struct DawgArgs {
DawgArgs(DawgPositionVector *d, DawgPositionVector *up, PermuterType p) DawgArgs(DawgPositionVector *d, DawgPositionVector *up, PermuterType p)
: active_dawgs(d), updated_dawgs(up), permuter(p) {} : active_dawgs(d), updated_dawgs(up), permuter(p), valid_end(false) {}
DawgPositionVector *active_dawgs; DawgPositionVector *active_dawgs;
DawgPositionVector *updated_dawgs; DawgPositionVector *updated_dawgs;
PermuterType permuter; PermuterType permuter;
// True if the current position is a valid word end.
bool valid_end;
}; };
class Dict { class Dict {
...@@ -294,7 +295,15 @@ class Dict { ...@@ -294,7 +295,15 @@ class Dict {
/// Initialize Dict class - load dawgs from [lang].traineddata and /// Initialize Dict class - load dawgs from [lang].traineddata and
/// user-specified wordlist and parttern list. /// user-specified wordlist and parttern list.
static DawgCache *GlobalDawgCache(); static DawgCache *GlobalDawgCache();
void Load(DawgCache *dawg_cache); // Sets up ready for a Load or LoadLSTM.
void SetupForLoad(DawgCache *dawg_cache);
// Loads the dawgs needed by Tesseract. Call FinishLoad() after.
void Load(const char *data_file_name, const STRING &lang);
// Loads the dawgs needed by the LSTM model. Call FinishLoad() after.
void LoadLSTM(const char *data_file_name, const STRING &lang);
// Completes the loading process after Load() and/or LoadLSTM().
// Returns false if no dictionaries were loaded.
bool FinishLoad();
void End(); void End();
// Resets the document dictionary analogous to ResetAdaptiveClassifier. // Resets the document dictionary analogous to ResetAdaptiveClassifier.
...@@ -397,9 +406,7 @@ class Dict { ...@@ -397,9 +406,7 @@ class Dict {
} }
inline void SetWildcardID(UNICHAR_ID id) { wildcard_unichar_id_ = id; } inline void SetWildcardID(UNICHAR_ID id) { wildcard_unichar_id_ = id; }
inline UNICHAR_ID WildcardID() const { inline UNICHAR_ID WildcardID() const { return wildcard_unichar_id_; }
return wildcard_unichar_id_;
}
/// Return the number of dawgs in the dawgs_ vector. /// Return the number of dawgs in the dawgs_ vector.
inline int NumDawgs() const { return dawgs_.size(); } inline int NumDawgs() const { return dawgs_.size(); }
/// Return i-th dawg pointer recorded in the dawgs_ vector. /// Return i-th dawg pointer recorded in the dawgs_ vector.
...@@ -436,7 +443,7 @@ class Dict { ...@@ -436,7 +443,7 @@ class Dict {
/// edges were found. /// edges were found.
void ProcessPatternEdges(const Dawg *dawg, const DawgPosition &info, void ProcessPatternEdges(const Dawg *dawg, const DawgPosition &info,
UNICHAR_ID unichar_id, bool word_end, UNICHAR_ID unichar_id, bool word_end,
DawgPositionVector *updated_dawgs, DawgArgs *dawg_args,
PermuterType *current_permuter) const; PermuterType *current_permuter) const;
/// Read/Write/Access special purpose dawgs which contain words /// Read/Write/Access special purpose dawgs which contain words
...@@ -483,6 +490,8 @@ class Dict { ...@@ -483,6 +490,8 @@ class Dict {
inline void SetWordsegRatingAdjustFactor(float f) { inline void SetWordsegRatingAdjustFactor(float f) {
wordseg_rating_adjust_factor_ = f; wordseg_rating_adjust_factor_ = f;
} }
/// Returns true if the language is space-delimited (not CJ, or T).
bool IsSpaceDelimitedLang() const;
private: private:
/** Private member variables. */ /** Private member variables. */
......
# Automake fragment for the lstm/ directory: builds the LSTM-based
# neural-network line recognizer as the libtesseract_lstm convenience
# (or installed) library.
AM_CPPFLAGS += \
-I$(top_srcdir)/ccutil -I$(top_srcdir)/cutil -I$(top_srcdir)/ccstruct \
-I$(top_srcdir)/arch -I$(top_srcdir)/viewer -I$(top_srcdir)/classify \
-I$(top_srcdir)/dict -I$(top_srcdir)/lstm
AUTOMAKE_OPTIONS = subdir-objects
SUBDIRS =
# OpenMP is used for parallelism inside the LSTM code.
AM_CXXFLAGS = -fopenmp

if !NO_TESSDATA_PREFIX
AM_CXXFLAGS += -DTESSDATA_PREFIX=@datadir@/
endif
if VISIBILITY
# Hide symbols by default; TESS_EXPORTS marks this library as the exporter.
AM_CXXFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden
AM_CPPFLAGS += -DTESS_EXPORTS
endif

include_HEADERS = \
  convolve.h ctc.h fullyconnected.h functions.h input.h \
  lstm.h lstmrecognizer.h lstmtrainer.h maxpool.h \
  networkbuilder.h network.h networkio.h networkscratch.h \
  parallel.h plumbing.h recodebeam.h reconfig.h reversed.h \
  series.h static_shape.h stridemap.h tfnetwork.h weightmatrix.h

noinst_HEADERS =

# When building a single combined libtesseract, this stays a convenience
# (noinst) library; otherwise it is installed with full version info.
if !USING_MULTIPLELIBS
noinst_LTLIBRARIES = libtesseract_lstm.la
else
lib_LTLIBRARIES = libtesseract_lstm.la
libtesseract_lstm_la_LDFLAGS = -version-info $(GENERIC_LIBRARY_VERSION)
endif

libtesseract_lstm_la_SOURCES = \
  convolve.cpp ctc.cpp fullyconnected.cpp functions.cpp input.cpp \
  lstm.cpp lstmrecognizer.cpp lstmtrainer.cpp maxpool.cpp \
  networkbuilder.cpp network.cpp networkio.cpp \
  parallel.cpp plumbing.cpp recodebeam.cpp reconfig.cpp reversed.cpp \
  series.cpp stridemap.cpp tfnetwork.cpp weightmatrix.cpp
///////////////////////////////////////////////////////////////////////
// File: convolve.cpp
// Description: Convolutional layer that stacks the inputs over its rectangle
// and pulls in random data to fill out-of-input inputs.
// Output is therefore same size as its input, but deeper.
// Author: Ray Smith
// Created: Tue Mar 18 16:56:06 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "convolve.h"
#include "networkscratch.h"
#include "serialis.h"
namespace tesseract {
// Constructs a convolve layer of the given window radii. The output depth is
// the input depth ni times the window area (2*half_x+1)*(2*half_y+1), since
// each output timestep stacks the whole window of input feature vectors.
Convolve::Convolve(const STRING& name, int ni, int half_x, int half_y)
  : Network(NT_CONVOLVE, name, ni, ni * (2*half_x + 1) * (2*half_y + 1)),
    half_x_(half_x), half_y_(half_y) {
}
// Destructor. Convolve owns no extra resources beyond the Network base.
Convolve::~Convolve() {
}
// Writes to the given file. Returns false in case of error.
bool Convolve::Serialize(TFile* fp) const {
  // Serialize the base class state first, then the two window radii.
  if (!Network::Serialize(fp)) return false;
  // Short-circuiting preserves the write order: half_x_ then half_y_.
  return fp->FWrite(&half_x_, sizeof(half_x_), 1) == 1 &&
         fp->FWrite(&half_y_, sizeof(half_y_), 1) == 1;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool Convolve::DeSerialize(bool swap, TFile* fp) {
  // Read both radii in write order; fail if either read comes up short.
  if (fp->FRead(&half_x_, sizeof(half_x_), 1) != 1 ||
      fp->FRead(&half_y_, sizeof(half_y_), 1) != 1) {
    return false;
  }
  if (swap) {
    // Byte-swap in place when the file was written with the other endianness.
    ReverseN(&half_x_, sizeof(half_x_));
    ReverseN(&half_y_, sizeof(half_y_));
  }
  // Recompute the output depth from the (possibly swapped) radii, matching
  // the formula used by the constructor.
  no_ = ni_ * (2*half_x_ + 1) * (2*half_y_ + 1);
  return true;
}
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
// For each output timestep t, stacks the (2*half_x_+1) x (2*half_y_+1) window
// of ni_-deep input vectors centered on t into one no_-deep output vector.
// Window positions that fall outside the image are filled with random data.
void Convolve::Forward(bool debug, const NetworkIO& input,
                       const TransposedArray* input_transpose,
                       NetworkScratch* scratch, NetworkIO* output) {
  // Same time dimension as input, but no_ channels per timestep.
  output->Resize(input, no_);
  int y_scale = 2 * half_y_ + 1;
  StrideMap::Index dest_index(output->stride_map());
  do {
    // Stack x_scale groups of y_scale * ni_ inputs together.
    int t = dest_index.t();
    int out_ix = 0;  // Offset within the output vector for the current x.
    for (int x = -half_x_; x <= half_x_; ++x, out_ix += y_scale * ni_) {
      StrideMap::Index x_index(dest_index);
      if (!x_index.AddOffset(x, FD_WIDTH)) {
        // This x is outside the image: fill the whole column with noise so
        // the network never learns to rely on the out-of-image values.
        output->Randomize(t, out_ix, y_scale * ni_, randomizer_);
      } else {
        int out_iy = out_ix;  // Offset within the output vector for each y.
        for (int y = -half_y_; y <= half_y_; ++y, out_iy += ni_) {
          StrideMap::Index y_index(x_index);
          if (!y_index.AddOffset(y, FD_HEIGHT)) {
            // This y is outside the image.
            output->Randomize(t, out_iy, ni_, randomizer_);
          } else {
            // Copy the ni_ input channels at (x, y) into the stacked output.
            output->CopyTimeStepGeneral(t, out_iy, ni_, input, y_index.t(), 0);
          }
        }
      }
    }
  } while (dest_index.Increment());
  if (debug) DisplayForward(*output);
}
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
// Since Forward copied each input vector into every window that contained it,
// the gradient of each input position is accumulated from the corresponding
// slice of every overlapping output timestep's deltas.
bool Convolve::Backward(bool debug, const NetworkIO& fwd_deltas,
                        NetworkScratch* scratch,
                        NetworkIO* back_deltas) {
  back_deltas->Resize(fwd_deltas, ni_);
  // Scratch accumulator, ni_ deep, zeroed before the gather below.
  NetworkScratch::IO delta_sum;
  delta_sum.ResizeFloat(fwd_deltas, ni_, scratch);
  delta_sum->Zero();
  int y_scale = 2 * half_y_ + 1;
  StrideMap::Index src_index(fwd_deltas.stride_map());
  do {
    // Stack x_scale groups of y_scale * ni_ inputs together.
    int t = src_index.t();
    int out_ix = 0;  // Offset of the (x, y) slice within the stacked deltas.
    for (int x = -half_x_; x <= half_x_; ++x, out_ix += y_scale * ni_) {
      StrideMap::Index x_index(src_index);
      // Out-of-image offsets are simply skipped: the random fill used in
      // Forward has no input to propagate a gradient to.
      if (x_index.AddOffset(x, FD_WIDTH)) {
        int out_iy = out_ix;
        for (int y = -half_y_; y <= half_y_; ++y, out_iy += ni_) {
          StrideMap::Index y_index(x_index);
          if (y_index.AddOffset(y, FD_HEIGHT)) {
            // Add the slice of t's deltas for window offset (x, y) into the
            // accumulator at the input position it came from.
            fwd_deltas.AddTimeStepPart(t, out_iy, ni_,
                                       delta_sum->f(y_index.t()));
          }
        }
      }
    }
  } while (src_index.Increment());
  // NOTE(review): presumably rescales the summed deltas to account for each
  // input contributing to multiple windows -- confirm in NetworkIO.
  back_deltas->CopyWithNormalization(*delta_sum, fwd_deltas);
  return true;
}
} // namespace tesseract.
///////////////////////////////////////////////////////////////////////
// File: convolve.h
// Description: Convolutional layer that stacks the inputs over its rectangle
// and pulls in random data to fill out-of-input inputs.
// Output is therefore same size as its input, but deeper.
// Author: Ray Smith
// Created: Tue Mar 18 16:45:34 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_CONVOLVE_H_
#define TESSERACT_LSTM_CONVOLVE_H_
#include "genericvector.h"
#include "matrix.h"
#include "network.h"
namespace tesseract {
// Makes each time-step deeper by stacking inputs over its rectangle. Does not
// affect the size of its input. Achieves this by bringing in random values in
// out-of-input areas.
class Convolve : public Network {
 public:
  // The area of convolution is 2*half_x + 1 by 2*half_y + 1, forcing it to
  // always be odd, so the center is the current pixel.
  // Output depth becomes ni * (2*half_x + 1) * (2*half_y + 1); the time
  // dimension is unchanged.
  Convolve(const STRING& name, int ni, int half_x, int half_y);
  virtual ~Convolve();

  // Returns the network spec string, eg "C3,3" for a 3x3 window.
  virtual STRING spec() const {
    STRING spec;
    spec.add_str_int("C", half_x_ * 2 + 1);
    spec.add_str_int(",", half_y_ * 2 + 1);
    return spec;
  }

  // Writes to the given file. Returns false in case of error.
  virtual bool Serialize(TFile* fp) const;
  // Reads from the given file. Returns false in case of error.
  // If swap is true, assumes a big/little-endian swap is needed.
  virtual bool DeSerialize(bool swap, TFile* fp);

  // Runs forward propagation of activations on the input line.
  // See Network for a detailed discussion of the arguments.
  virtual void Forward(bool debug, const NetworkIO& input,
                       const TransposedArray* input_transpose,
                       NetworkScratch* scratch, NetworkIO* output);

  // Runs backward propagation of errors on the deltas line.
  // See Network for a detailed discussion of the arguments.
  virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
                        NetworkScratch* scratch,
                        NetworkIO* back_deltas);

 protected:
  // Serialized data: the half-widths of the convolution window in x and y.
  inT32 half_x_;
  inT32 half_y_;
};
} // namespace tesseract.
#endif  // TESSERACT_LSTM_CONVOLVE_H_
///////////////////////////////////////////////////////////////////////
// File: ctc.cpp
// Description: Slightly improved standard CTC to compute the targets.
// Author: Ray Smith
// Created: Wed Jul 13 15:50:06 PDT 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "ctc.h"
#include <memory>
#include "genericvector.h"
#include "host.h"
#include "matrix.h"
#include "networkio.h"
#include "network.h"
#include "scrollview.h"
namespace tesseract {
// Magic constants that keep CTC stable.
// Minimum probability limit for softmax input to ctc_loss; also the
// per-class floor applied by NormalizeProbs.
const float CTC::kMinProb_ = 1e-12;
// Maximum absolute argument to exp(), used by ClippedExp to avoid
// over/underflow.
const double CTC::kMaxExpArg_ = 80.0;
// Minimum probability for total prob in time normalization.
const double CTC::kMinTotalTimeProb_ = 1e-8;
// Minimum probability for total prob in final normalization.
const double CTC::kMinTotalFinalProb_ = 1e-6;
// Builds a target using CTC. Slightly improved as follows:
// Includes normalizations and clipping for stability.
// labels should be pre-padded with nulls everywhere.
// labels can be longer than the time sequence, but the total number of
// essential labels (non-null plus nulls between equal labels) must not exceed
// the number of timesteps in outputs.
// outputs is the output of the network, and should have already been
// normalized with NormalizeProbs.
// On return targets is filled with the computed targets.
// Returns false if there is insufficient time for the labels.
/* static */
bool CTC::ComputeCTCTargets(const GenericVector<int>& labels, int null_char,
                            const GENERIC_2D_ARRAY<float>& outputs,
                            NetworkIO* targets) {
  std::unique_ptr<CTC> solver(new CTC(labels, null_char, outputs));
  // Bail out early if the label sequence cannot fit in the timesteps.
  if (!solver->ComputeLabelLimits()) return false;
  // Build naive targets by spreading the truth labels evenly over time, and
  // mix a fraction of them into the network outputs as a starter bias.
  GENERIC_2D_ARRAY<float> simple_targets;
  solver->ComputeSimpleTargets(&simple_targets);
  simple_targets *= solver->CalculateBiasFraction();
  solver->outputs_ += simple_targets;
  NormalizeProbs(&solver->outputs_);
  // Run standard CTC forward/backward passes on the biased outputs.
  GENERIC_2D_ARRAY<double> log_alphas, log_betas;
  solver->Forward(&log_alphas);
  solver->Backward(&log_betas);
  log_alphas += log_betas;
  // Leave log space with a clipped softmax over time, then map the
  // label-indexed probs onto class-indexed targets.
  solver->NormalizeSequence(&log_alphas);
  solver->LabelsToClasses(log_alphas, targets);
  NormalizeProbs(targets);
  return true;
}
// Private constructor: binds one (labels, outputs) pair for a single
// target-generation call. Keeps a reference to labels (the caller must keep
// it alive for the lifetime of this object) and copies outputs.
CTC::CTC(const GenericVector<int>& labels, int null_char,
         const GENERIC_2D_ARRAY<float>& outputs)
    : labels_(labels), outputs_(outputs), null_char_(null_char) {
  num_timesteps_ = outputs.dim1();
  num_classes_ = outputs.dim2();
  num_labels_ = labels_.size();
}
// Computes vectors of min and max label index for each timestep, based on
// whether skippability of nulls makes it possible to complete a valid path.
// Returns false if the labels cannot fit in num_timesteps_.
bool CTC::ComputeLabelLimits() {
  min_labels_.init_to_size(num_timesteps_, 0);
  max_labels_.init_to_size(num_timesteps_, 0);
  // Backward sweep: min_u is the lowest label index from which the end of
  // the label string is still reachable at time t.
  int min_u = num_labels_ - 1;
  if (labels_[min_u] == null_char_) --min_u;  // A trailing null is optional.
  for (int t = num_timesteps_ - 1; t >= 0; --t) {
    min_labels_[t] = min_u;
    if (min_u > 0) {
      --min_u;
      // A null between two different labels is skippable, so one timestep
      // can step back over the null and its preceding label together.
      if (labels_[min_u] == null_char_ && min_u > 0 &&
          labels_[min_u + 1] != labels_[min_u - 1]) {
        --min_u;
      }
    }
  }
  // Forward sweep: max_u is the highest label index reachable by time t.
  int max_u = labels_[0] == null_char_;  // A leading null is optional.
  for (int t = 0; t < num_timesteps_; ++t) {
    max_labels_[t] = max_u;
    if (max_labels_[t] < min_labels_[t]) return false;  // Not enough room.
    if (max_u + 1 < num_labels_) {
      ++max_u;
      // Advance over a skippable null (one not between equal labels).
      if (labels_[max_u] == null_char_ && max_u + 1 < num_labels_ &&
          labels_[max_u + 1] != labels_[max_u - 1]) {
        ++max_u;
      }
    }
  }
  return true;
}
// Computes targets based purely on the labels by spreading the labels evenly
// over the available timesteps. Each kept label becomes a triangular peak of
// height 1 at its mean position, ramping down linearly over its half width.
void CTC::ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const {
  // Initialize all targets to zero.
  targets->Resize(num_timesteps_, num_classes_, 0.0f);
  GenericVector<float> half_widths;
  GenericVector<int> means;
  ComputeWidthsAndMeans(&half_widths, &means);
  for (int l = 0; l < num_labels_; ++l) {
    int label = labels_[l];
    float left_half_width = half_widths[l];
    float right_half_width = left_half_width;
    int mean = means[l];
    if (label == null_char_) {
      if (!NeededNull(l)) {
        // An optional null whose mean collides with a neighbor's adds no
        // information, so it is dropped entirely.
        if ((l > 0 && mean == means[l - 1]) ||
            (l + 1 < num_labels_ && mean == means[l + 1])) {
          continue;  // Drop overlapping null.
        }
      }
      // Make sure that no space is left unoccupied and that non-nulls always
      // peak at 1 by stretching nulls to meet their neighbors.
      if (l > 0) left_half_width = mean - means[l - 1];
      if (l + 1 < num_labels_) right_half_width = means[l + 1] - mean;
    }
    // Peak of 1 at the mean position, if in range.
    if (mean >= 0 && mean < num_timesteps_) targets->put(mean, label, 1.0f);
    // Linear ramp down to the left of the mean, keeping the max where ramps
    // of the same label overlap.
    for (int offset = 1; offset < left_half_width && mean >= offset; ++offset) {
      float prob = 1.0f - offset / left_half_width;
      if (mean - offset < num_timesteps_ &&
          prob > targets->get(mean - offset, label)) {
        targets->put(mean - offset, label, prob);
      }
    }
    // Matching ramp down to the right of the mean.
    for (int offset = 1;
         offset < right_half_width && mean + offset < num_timesteps_;
         ++offset) {
      float prob = 1.0f - offset / right_half_width;
      if (mean + offset >= 0 && prob > targets->get(mean + offset, label)) {
        targets->put(mean + offset, label, prob);
      }
    }
  }
}
// Computes mean positions and half widths of the simple targets by spreading
// the labels evenly over the available timesteps.
void CTC::ComputeWidthsAndMeans(GenericVector<float>* half_widths,
                                GenericVector<int>* means) const {
  // Classify each label, in regexp terms: "required" (plus) labels are
  // non-nulls or needed nulls that must occupy at least one timestep;
  // "optional" (star) nulls may shrink to nothing.
  int required_count = 0;
  int optional_count = 0;
  for (int i = 0; i < num_labels_; ++i) {
    bool required = labels_[i] != null_char_ || NeededNull(i);
    if (required) {
      ++required_count;
    } else {
      ++optional_count;
    }
  }
  // Choose a width for each kind. When there is room for everything at
  // size >= 1, all labels share the time equally; otherwise required labels
  // get exactly 1 and optional nulls split whatever remains.
  float required_size = 1.0f;
  float optional_size = 0.0f;
  float total_floating = required_count + optional_count;
  if (total_floating <= num_timesteps_) {
    required_size = optional_size = num_timesteps_ / total_floating;
  } else if (optional_count > 0) {
    optional_size =
        static_cast<float>(num_timesteps_ - required_count) / optional_count;
  }
  // Lay the labels out left-to-right, recording each center and half width.
  float cursor = 0.0f;
  for (int i = 0; i < num_labels_; ++i) {
    float half_width = optional_size / 2.0f;
    if (labels_[i] != null_char_ || NeededNull(i)) {
      half_width = required_size / 2.0f;
    }
    cursor += half_width;
    means->push_back(static_cast<int>(cursor));
    cursor += half_width;
    half_widths->push_back(half_width);
  }
}
// Helper returns the index of the highest probability label at timestep t.
// Ties are broken in favor of the lowest class index.
static int BestLabel(const GENERIC_2D_ARRAY<float>& outputs, int t) {
  const float* row = outputs[t];
  const int num_classes = outputs.dim2();
  int best = 0;
  for (int c = 1; c < num_classes; ++c) {
    if (row[c] > row[best]) best = c;
  }
  return best;
}
// Calculates and returns a suitable fraction of the simple targets to add
// to the network outputs.
float CTC::CalculateBiasFraction() {
// Compute output labels via basic decoding.
GenericVector<int> output_labels;
for (int t = 0; t < num_timesteps_; ++t) {
int label = BestLabel(outputs_, t);
while (t + 1 < num_timesteps_ && BestLabel(outputs_, t + 1) == label) ++t;
if (label != null_char_) output_labels.push_back(label);
}
// Simple bag of labels error calculation.
GenericVector<int> truth_counts(num_classes_, 0);
GenericVector<int> output_counts(num_classes_, 0);
for (int l = 0; l < num_labels_; ++l) {
++truth_counts[labels_[l]];
}
for (int l = 0; l < output_labels.size(); ++l) {
++output_counts[output_labels[l]];
}
// Count the number of true and false positive non-nulls and truth labels.
int true_pos = 0, false_pos = 0, total_labels = 0;
for (int c = 0; c < num_classes_; ++c) {
if (c == null_char_) continue;
int truth_count = truth_counts[c];
int ocr_count = output_counts[c];
if (truth_count > 0) {
total_labels += truth_count;
if (ocr_count > truth_count) {
true_pos += truth_count;
false_pos += ocr_count - truth_count;
} else {
true_pos += ocr_count;
}
}
// We don't need to count classes that don't exist in the truth as
// false positives, because they don't affect CTC at all.
}
if (total_labels == 0) return 0.0f;
return exp(MAX(true_pos - false_pos, 1) * log(kMinProb_) / total_labels);
}
// Given ln(x) and ln(y), returns ln(x + y), using:
// ln(x + y) = ln(x) + ln(1 + exp(ln(y) - ln(x))), where ln(x) is made the
// bigger of the two numbers to maximize precision.
static double LogSumExp(double ln_x, double ln_y) {
  double hi = ln_x;
  double lo = ln_y;
  if (hi < lo) {
    hi = ln_y;
    lo = ln_x;
  }
  return hi + log1p(exp(lo - hi));
}
// Runs the forward CTC pass, filling in log_probs.
// log_probs->get(t, u) becomes the log of the total probability of all paths
// that account for labels [0, u] using timesteps [0, t].
void CTC::Forward(GENERIC_2D_ARRAY<double>* log_probs) const {
  // -MAX_FLOAT32 stands in for log(0): an impossible path.
  log_probs->Resize(num_timesteps_, num_labels_, -MAX_FLOAT32);
  log_probs->put(0, 0, log(outputs_(0, labels_[0])));
  // A leading null may be skipped, so the first real label is also a valid
  // starting point.
  if (labels_[0] == null_char_)
    log_probs->put(0, 1, log(outputs_(0, labels_[1])));
  for (int t = 1; t < num_timesteps_; ++t) {
    const float* outputs_t = outputs_[t];
    // Only label indices reachable at time t (see ComputeLabelLimits) need
    // to be computed.
    for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
      // Continuing the same label.
      double log_sum = log_probs->get(t - 1, u);
      // Change from previous label.
      if (u > 0) {
        log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 1));
      }
      // Skip the null if allowed (only between two different labels).
      if (u >= 2 && labels_[u - 1] == null_char_ &&
          labels_[u] != labels_[u - 2]) {
        log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 2));
      }
      // Add in the log prob of the current label.
      double label_prob = outputs_t[labels_[u]];
      log_sum += log(label_prob);
      log_probs->put(t, u, log_sum);
    }
  }
}
// Runs the backward CTC pass, filling in log_probs.
// log_probs->get(t, u) becomes the log of the total probability of all paths
// that complete labels [u, end) using timesteps (t, end).
void CTC::Backward(GENERIC_2D_ARRAY<double>* log_probs) const {
  // -MAX_FLOAT32 stands in for log(0): an impossible path.
  log_probs->Resize(num_timesteps_, num_labels_, -MAX_FLOAT32);
  log_probs->put(num_timesteps_ - 1, num_labels_ - 1, 0.0);
  // A trailing null may be skipped, so the last real label is also a valid
  // ending point.
  if (labels_[num_labels_ - 1] == null_char_)
    log_probs->put(num_timesteps_ - 1, num_labels_ - 2, 0.0);
  for (int t = num_timesteps_ - 2; t >= 0; --t) {
    const float* outputs_tp1 = outputs_[t + 1];
    // Only label indices reachable at time t (see ComputeLabelLimits) need
    // to be computed.
    for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
      // Continuing the same label.
      double log_sum = log_probs->get(t + 1, u) + log(outputs_tp1[labels_[u]]);
      // Change from previous label.
      if (u + 1 < num_labels_) {
        double prev_prob = outputs_tp1[labels_[u + 1]];
        log_sum =
            LogSumExp(log_sum, log_probs->get(t + 1, u + 1) + log(prev_prob));
      }
      // Skip the null if allowed (only between two different labels).
      if (u + 2 < num_labels_ && labels_[u + 1] == null_char_ &&
          labels_[u] != labels_[u + 2]) {
        double skip_prob = outputs_tp1[labels_[u + 2]];
        log_sum =
            LogSumExp(log_sum, log_probs->get(t + 1, u + 2) + log(skip_prob));
      }
      log_probs->put(t, u, log_sum);
    }
  }
}
// Normalizes and brings probs out of log space with a softmax over time.
// Each label (column u) is normalized independently so its probabilities
// over the timesteps sum to 1, unless the whole column is negligible.
void CTC::NormalizeSequence(GENERIC_2D_ARRAY<double>* probs) const {
  // Global max keeps the exps in range before normalization.
  double max_logprob = probs->Max();
  for (int u = 0; u < num_labels_; ++u) {
    double total = 0.0;
    for (int t = 0; t < num_timesteps_; ++t) {
      // Separate impossible path from unlikely probs.
      double prob = probs->get(t, u);
      if (prob > -MAX_FLOAT32)
        prob = ClippedExp(prob - max_logprob);
      else
        prob = 0.0;  // Impossible paths stay exactly zero.
      total += prob;
      probs->put(t, u, prob);
    }
    // Note that although this is a probability distribution over time and
    // therefore should sum to 1, it is important to allow some labels to be
    // all zero, (or at least tiny) as it is necessary to skip some blanks.
    if (total < kMinTotalTimeProb_) total = kMinTotalTimeProb_;
    for (int t = 0; t < num_timesteps_; ++t)
      probs->put(t, u, probs->get(t, u) / total);
  }
}
// For each timestep computes the max prob for each class over all
// instances of the class in the labels_, and sets the targets to
// the max observed prob.
void CTC::LabelsToClasses(const GENERIC_2D_ARRAY<double>& probs,
                          NetworkIO* targets) const {
  // For each timestep compute the max prob for each class over all
  // instances of the class in the labels_.
  GenericVector<double> class_probs;
  for (int t = 0; t < num_timesteps_; ++t) {
    float* targets_t = targets->f(t);
    class_probs.init_to_size(num_classes_, 0.0);
    for (int u = 0; u < num_labels_; ++u) {
      double prob = probs(t, u);
      // Note that although Graves specifies sum over all labels of the same
      // class, we need to allow skipped blanks to go to zero, so they don't
      // interfere with the non-blanks, so max is better than sum.
      if (prob > class_probs[labels_[u]]) class_probs[labels_[u]] = prob;
      // class_probs[labels_[u]] += prob;
    }
    // Copy the per-class maxima out to this timestep's targets. (The old code
    // also tracked a best_class here, but the result was never used.)
    for (int c = 0; c < num_classes_; ++c) targets_t[c] = class_probs[c];
  }
}
// Normalizes the probabilities such that no target has a prob below min_prob,
// and, provided that the initial total is at least min_total_prob, then all
// probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
// probability is thus 1 - (num_classes-1)*min_prob.
/* static */
void CTC::NormalizeProbs(GENERIC_2D_ARRAY<float>* probs) {
  const int num_timesteps = probs->dim1();
  const int num_classes = probs->dim2();
  for (int t = 0; t < num_timesteps; ++t) {
    float* row = (*probs)[t];
    // Total for this timestep, clipped to prevent amplification of noise.
    double total = 0.0;
    for (int c = 0; c < num_classes; ++c) total += row[c];
    if (total < kMinTotalFinalProb_) total = kMinTotalFinalProb_;
    // Pre-compute the extra mass that flooring at kMinProb_ would add, and
    // fold it into the divisor so the result still sums (almost) to 1.
    double increment = 0.0;
    for (int c = 0; c < num_classes; ++c) {
      double prob = row[c] / total;
      if (prob < kMinProb_) increment += kMinProb_ - prob;
    }
    total += increment;
    // Normalize with clipping. Any additional clipping is negligible.
    for (int c = 0; c < num_classes; ++c) {
      float prob = row[c] / total;
      row[c] = MAX(prob, kMinProb_);
    }
  }
}
// Returns true if the label at index is a needed null: a null sandwiched
// between two equal labels, which CTC cannot skip.
bool CTC::NeededNull(int index) const {
  if (labels_[index] != null_char_) return false;
  if (index <= 0 || index + 1 >= num_labels_) return false;
  return labels_[index + 1] == labels_[index - 1];
}
} // namespace tesseract
///////////////////////////////////////////////////////////////////////
// File: ctc.h
// Description: Slightly improved standard CTC to compute the targets.
// Author: Ray Smith
// Created: Wed Jul 13 15:17:06 PDT 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_CTC_H_
#define TESSERACT_LSTM_CTC_H_
#include "genericvector.h"
#include "network.h"
#include "networkio.h"
#include "scrollview.h"
namespace tesseract {
// Class to encapsulate CTC and simple target generation.
// All use is through the static functions; instances are created internally
// by ComputeCTCTargets to hold the per-call state.
class CTC {
 public:
  // Normalizes the probabilities such that no target has a prob below min_prob,
  // and, provided that the initial total is at least min_total_prob, then all
  // probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
  // probability is thus 1 - (num_classes-1)*min_prob.
  static void NormalizeProbs(NetworkIO* probs) {
    NormalizeProbs(probs->mutable_float_array());
  }
  // Builds a target using CTC. Slightly improved as follows:
  // Includes normalizations and clipping for stability.
  // labels should be pre-padded with nulls wherever desired, but they don't
  // have to be between all labels. Allows for multi-label codes with no
  // nulls between.
  // labels can be longer than the time sequence, but the total number of
  // essential labels (non-null plus nulls between equal labels) must not exceed
  // the number of timesteps in outputs.
  // outputs is the output of the network, and should have already been
  // normalized with NormalizeProbs.
  // On return targets is filled with the computed targets.
  // Returns false if there is insufficient time for the labels.
  static bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
                                int null_char,
                                const GENERIC_2D_ARRAY<float>& outputs,
                                NetworkIO* targets);

 private:
  // Constructor is private as the instance only holds information specific to
  // the current labels, outputs etc, and is built by the static function.
  CTC(const GenericVector<int>& labels, int null_char,
      const GENERIC_2D_ARRAY<float>& outputs);

  // Computes vectors of min and max label index for each timestep, based on
  // whether skippability of nulls makes it possible to complete a valid path.
  bool ComputeLabelLimits();
  // Computes targets based purely on the labels by spreading the labels evenly
  // over the available timesteps.
  void ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const;
  // Computes mean positions and half widths of the simple targets by spreading
  // the labels evenly over the available timesteps.
  void ComputeWidthsAndMeans(GenericVector<float>* half_widths,
                             GenericVector<int>* means) const;
  // Calculates and returns a suitable fraction of the simple targets to add
  // to the network outputs.
  float CalculateBiasFraction();
  // Runs the forward CTC pass, filling in log_probs.
  void Forward(GENERIC_2D_ARRAY<double>* log_probs) const;
  // Runs the backward CTC pass, filling in log_probs.
  void Backward(GENERIC_2D_ARRAY<double>* log_probs) const;
  // Normalizes and brings probs out of log space with a softmax over time.
  void NormalizeSequence(GENERIC_2D_ARRAY<double>* probs) const;
  // For each timestep computes the max prob for each class over all
  // instances of the class in the labels_, and sets the targets to
  // the max observed prob.
  void LabelsToClasses(const GENERIC_2D_ARRAY<double>& probs,
                       NetworkIO* targets) const;
  // Normalizes the probabilities such that no target has a prob below min_prob,
  // and, provided that the initial total is at least min_total_prob, then all
  // probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
  // probability is thus 1 - (num_classes-1)*min_prob.
  static void NormalizeProbs(GENERIC_2D_ARRAY<float>* probs);
  // Returns true if the label at index is a needed null.
  bool NeededNull(int index) const;
  // Returns exp(clipped(x)), clipping x to a reasonable range to prevent over/
  // underflow.
  static double ClippedExp(double x) {
    if (x < -kMaxExpArg_) return exp(-kMaxExpArg_);
    if (x > kMaxExpArg_) return exp(kMaxExpArg_);
    return exp(x);
  }

  // Minimum probability limit for softmax input to ctc_loss.
  static const float kMinProb_;
  // Maximum absolute argument to exp().
  static const double kMaxExpArg_;
  // Minimum probability for total prob in time normalization.
  static const double kMinTotalTimeProb_;
  // Minimum probability for total prob in final normalization.
  static const double kMinTotalFinalProb_;

  // The truth label indices that are to be matched to outputs_.
  const GenericVector<int>& labels_;
  // The network outputs.
  GENERIC_2D_ARRAY<float> outputs_;
  // The null or "blank" label.
  int null_char_;
  // Number of timesteps in outputs_.
  int num_timesteps_;
  // Number of classes in outputs_.
  int num_classes_;
  // Number of labels in labels_.
  int num_labels_;
  // Min and max valid label indices for each timestep.
  GenericVector<int> min_labels_;
  GenericVector<int> max_labels_;
};
} // namespace tesseract
#endif // TESSERACT_LSTM_CTC_H_
此差异已折叠。
///////////////////////////////////////////////////////////////////////
// File: fullyconnected.h
// Description: Simple feed-forward layer with various non-linearities.
// Author: Ray Smith
// Created: Wed Feb 26 14:46:06 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_FULLYCONNECTED_H_
#define TESSERACT_LSTM_FULLYCONNECTED_H_
#include "network.h"
#include "networkscratch.h"
namespace tesseract {
// C++ Implementation of the Softmax (output) class from lstm.py.
// Despite the name, implements a general single fully-connected layer; the
// non-linearity is selected by the NetworkType (see spec() below).
class FullyConnected : public Network {
 public:
  FullyConnected(const STRING& name, int ni, int no, NetworkType type);
  virtual ~FullyConnected();

  // Returns the shape output from the network given an input shape (which may
  // be partially unknown ie zero).
  virtual StaticShape OutputShape(const StaticShape& input_shape) const;

  // Returns a two-letter code for the non-linearity plus the output depth.
  // NOTE(review): NT_LOGISTIC and NT_SYMCLIP both emit "Fs", making them
  // indistinguishable in a spec string -- confirm this is intended.
  virtual STRING spec() const {
    STRING spec;
    if (type_ == NT_TANH)
      spec.add_str_int("Ft", no_);
    else if (type_ == NT_LOGISTIC)
      spec.add_str_int("Fs", no_);
    else if (type_ == NT_RELU)
      spec.add_str_int("Fr", no_);
    else if (type_ == NT_LINEAR)
      spec.add_str_int("Fl", no_);
    else if (type_ == NT_POSCLIP)
      spec.add_str_int("Fp", no_);
    else if (type_ == NT_SYMCLIP)
      spec.add_str_int("Fs", no_);
    else if (type_ == NT_SOFTMAX)
      spec.add_str_int("Fc", no_);
    else
      spec.add_str_int("Fm", no_);
    return spec;
  }

  // Changes the type to the given type. Used to commute a softmax to a
  // non-output type for adding on other networks.
  void ChangeType(NetworkType type) {
    type_ = type;
  }

  // Sets up the network for training. Initializes weights using weights of
  // scale `range` picked according to the random number generator `randomizer`.
  virtual int InitWeights(float range, TRand* randomizer);
  // Converts a float network to an int network.
  virtual void ConvertToInt();
  // Provides debug output on the weights.
  virtual void DebugWeights();

  // Writes to the given file. Returns false in case of error.
  virtual bool Serialize(TFile* fp) const;
  // Reads from the given file. Returns false in case of error.
  // If swap is true, assumes a big/little-endian swap is needed.
  virtual bool DeSerialize(bool swap, TFile* fp);

  // Runs forward propagation of activations on the input line.
  // See Network for a detailed discussion of the arguments.
  virtual void Forward(bool debug, const NetworkIO& input,
                       const TransposedArray* input_transpose,
                       NetworkScratch* scratch, NetworkIO* output);
  // Components of Forward so FullyConnected can be reused inside LSTM.
  void SetupForward(const NetworkIO& input,
                    const TransposedArray* input_transpose);
  void ForwardTimeStep(const double* d_input, const inT8* i_input, int t,
                       double* output_line);

  // Runs backward propagation of errors on the deltas line.
  // See Network for a detailed discussion of the arguments.
  virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
                        NetworkScratch* scratch,
                        NetworkIO* back_deltas);
  // Components of Backward so FullyConnected can be reused inside LSTM.
  void BackwardTimeStep(const NetworkIO& fwd_deltas, int t, double* curr_errors,
                        TransposedArray* errors_t, double* backprop);
  void FinishBackward(const TransposedArray& errors_t);

  // Updates the weights using the given learning rate and momentum.
  // num_samples is the quotient to be used in the adagrad computation iff
  // use_ada_grad_ is true.
  virtual void Update(float learning_rate, float momentum, int num_samples);
  // Sums the products of weight updates in *this and other, splitting into
  // positive (same direction) in *same and negative (different direction) in
  // *changed.
  virtual void CountAlternators(const Network& other, double* same,
                                double* changed) const;

 protected:
  // Weight arrays of size [no, ni + 1].
  WeightMatrix weights_;
  // Transposed copy of input used during training of size [ni, width].
  TransposedArray source_t_;
  // Pointer to transposed input stored elsewhere. If not null, this is used
  // in preference to calculating the transpose and storing it in source_t_.
  const TransposedArray* external_source_;
  // Activations from forward pass of size [width, no].
  NetworkIO acts_;
  // Memory of the integer mode input to forward as softmax always outputs
  // float, so the information is otherwise lost.
  bool int_mode_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_FULLYCONNECTED_H_
///////////////////////////////////////////////////////////////////////
// File: functions.cpp
// Description: Static initialize-on-first-use non-linearity functions.
// Author: Ray Smith
// Created: Tue Jul 17 14:02:59 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "functions.h"
namespace tesseract {

// Shared lookup tables for the tanh and logistic non-linearities.
// kTableSize comes from functions.h; per the file header these are
// initialized on first use.
double TanhTable[kTableSize];
double LogisticTable[kTableSize];

}  // namespace tesseract.
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
///////////////////////////////////////////////////////////////////////
// File: maxpool.h
// Description: Standard Max-Pooling layer.
// Author: Ray Smith
// Created: Tue Mar 18 16:28:18 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "maxpool.h"
#include "tprintf.h"
namespace tesseract {
// A Maxpool is a Reconfig that keeps the element-wise max over its
// x_scale by y_scale window instead of stacking the inputs, so the output
// depth stays equal to the input depth.
Maxpool::Maxpool(const STRING& name, int ni, int x_scale, int y_scale)
  : Reconfig(name, ni, x_scale, y_scale) {
  type_ = NT_MAXPOOL;
  no_ = ni;  // Output depth equals input depth, unlike plain Reconfig.
}

Maxpool::~Maxpool() {
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool Maxpool::DeSerialize(bool swap, TFile* fp) {
  const bool ok = Reconfig::DeSerialize(swap, fp);
  // Maxpooling keeps the input depth, whatever the serialized no_ said.
  no_ = ni_;
  return ok;
}
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
// Each output timestep takes the per-dimension max over the x_scale_ by
// y_scale_ window of input timesteps, recording in maxes_ which input
// timestep supplied each max, for use by Backward.
void Maxpool::Forward(bool debug, const NetworkIO& input,
                      const TransposedArray* input_transpose,
                      NetworkScratch* scratch, NetworkIO* output) {
  output->ResizeScaled(input, x_scale_, y_scale_, no_);
  maxes_.ResizeNoInit(output->Width(), ni_);
  back_map_ = input.stride_map();  // Saved so Backward can restore the shape.
  StrideMap::Index dest_index(output->stride_map());
  do {
    int out_t = dest_index.t();
    // Top-left corner of the pooling window in the input.
    StrideMap::Index src_index(input.stride_map(), dest_index.index(FD_BATCH),
                               dest_index.index(FD_HEIGHT) * y_scale_,
                               dest_index.index(FD_WIDTH) * x_scale_);
    // Find the max input out of x_scale_ groups of y_scale_ inputs.
    // Do it independently for each input dimension.
    int* max_line = maxes_[out_t];
    int in_t = src_index.t();
    // Seed with the window's first timestep, then max in the rest.
    output->CopyTimeStepFrom(out_t, input, in_t);
    for (int i = 0; i < ni_; ++i) {
      max_line[i] = in_t;
    }
    for (int x = 0; x < x_scale_; ++x) {
      for (int y = 0; y < y_scale_; ++y) {
        StrideMap::Index src_xy(src_index);
        // AddOffset fails at the input boundary, truncating the window.
        if (src_xy.AddOffset(x, FD_WIDTH) && src_xy.AddOffset(y, FD_HEIGHT)) {
          output->MaxpoolTimeStep(out_t, input, src_xy.t(), max_line);
        }
      }
    }
  } while (dest_index.Increment());
}
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
// Uses maxes_ (recorded by Forward) to route each delta; presumably
// MaxpoolBackward scatters deltas to the winning input timesteps -- see
// NetworkIO for the exact semantics.
bool Maxpool::Backward(bool debug, const NetworkIO& fwd_deltas,
                       NetworkScratch* scratch,
                       NetworkIO* back_deltas) {
  // back_map_ restores the pre-pooling shape saved in Forward.
  back_deltas->ResizeToMap(fwd_deltas.int_mode(), back_map_, ni_);
  back_deltas->MaxpoolBackward(fwd_deltas, maxes_);
  return true;
}
} // namespace tesseract.
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -850,7 +850,8 @@ void BaselineDetect::ComputeBaselineSplinesAndXheights(const ICOORD& page_tr, ...@@ -850,7 +850,8 @@ void BaselineDetect::ComputeBaselineSplinesAndXheights(const ICOORD& page_tr,
Pix* pix_spline = pix_debug_ ? pixConvertTo32(pix_debug_) : NULL; Pix* pix_spline = pix_debug_ ? pixConvertTo32(pix_debug_) : NULL;
for (int i = 0; i < blocks_.size(); ++i) { for (int i = 0; i < blocks_.size(); ++i) {
BaselineBlock* bl_block = blocks_[i]; BaselineBlock* bl_block = blocks_[i];
bl_block->PrepareForSplineFitting(page_tr, remove_noise); if (enable_splines)
bl_block->PrepareForSplineFitting(page_tr, remove_noise);
bl_block->FitBaselineSplines(enable_splines, show_final_rows, textord); bl_block->FitBaselineSplines(enable_splines, show_final_rows, textord);
if (pix_spline) { if (pix_spline) {
bl_block->DrawPixSpline(pix_spline); bl_block->DrawPixSpline(pix_spline);
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册