提交 3ec11bd3 编写于 作者: R Ray Smith

Deleted some dead LSTM code, making everything use the recoder

上级 aee910a7
......@@ -21,6 +21,7 @@
#define TESSERACT_API_APITYPES_H_
#include "publictypes.h"
#include "version.h"
// The types used by the API and Page/ResultIterator can be found in:
// ccstruct/publictypes.h
......
......@@ -20,11 +20,6 @@
#ifndef TESSERACT_API_BASEAPI_H_
#define TESSERACT_API_BASEAPI_H_
#define TESSERACT_VERSION_STR "4.00.00alpha"
#define TESSERACT_VERSION 0x040000
#define MAKE_VERSION(major, minor, patch) (((major) << 16) | ((minor) << 8) | \
(patch))
#include <stdio.h>
// To avoid collision with other typenames include the ABSOLUTE MINIMUM
// complexity of includes here. Use forward declarations wherever possible
......
......@@ -31,9 +31,6 @@
namespace tesseract {
// Arbitarary penalty for non-dictionary words.
// TODO(rays) How to learn this?
const float kNonDictionaryPenalty = 5.0f;
// Scale factor to make certainty more comparable to Tesseract.
const float kCertaintyScale = 7.0f;
// Worst acceptable certainty for a dictionary word.
......@@ -241,8 +238,7 @@ void Tesseract::LSTMRecognizeWord(const BLOCK& block, ROW *row, WERD_RES *word,
if (im_data == NULL) return;
lstm_recognizer_->RecognizeLine(*im_data, true, classify_debug_level > 0,
kWorstDictCertainty / kCertaintyScale,
lstm_use_matrix, &unicharset, word_box, 2.0,
false, words);
word_box, words);
delete im_data;
SearchWords(words);
}
......@@ -268,17 +264,6 @@ void Tesseract::SearchWords(PointerVector<WERD_RES>* words) {
}
for (int w = 0; w < words->size(); ++w) {
WERD_RES* word = (*words)[w];
if (word->best_choice == NULL) {
// If we are using the beam search, the unicharset had better match!
word->SetupWordScript(unicharset);
WordSearch(word);
} else if (word->best_choice->unicharset() == &unicharset &&
!lstm_recognizer_->IsRecoding()) {
// We set up the word without using the dictionary, so set the permuter
// now, but we can only do it because the unicharsets match.
word->best_choice->set_permuter(
getDict().valid_word(*word->best_choice, true));
}
if (word->best_choice == NULL) {
// It is a dud.
word->SetupFake(lstm_recognizer_->GetUnicharset());
......@@ -297,10 +282,6 @@ void Tesseract::SearchWords(PointerVector<WERD_RES>* words) {
float word_certainty = MIN(word->space_certainty,
word->best_choice->certainty());
word_certainty *= kCertaintyScale;
// Arbitrary ding factor for non-dictionary words.
if (!lstm_recognizer_->IsRecoding() &&
!Dict::valid_word_permuter(word->best_choice->permuter(), true))
word_certainty -= kNonDictionaryPenalty;
if (getDict().stopper_debug_level >= 1) {
tprintf("Best choice certainty=%g, space=%g, scaled=%g, final=%g\n",
word->best_choice->certainty(), word->space_certainty,
......
......@@ -238,6 +238,11 @@ void UnicharCompress::SetupPassThrough(const UNICHARSET& unicharset) {
code.Set(0, u);
codes.push_back(code);
}
if (!unicharset.has_special_codes()) {
RecodedCharID code;
code.Set(0, unicharset.size());
codes.push_back(code);
}
SetupDirect(codes);
}
......
......@@ -115,6 +115,9 @@ bool LSTMRecognizer::DeSerialize(TFile* fp) {
tprintf("Space was garbled in recoding!!\n");
return false;
}
} else {
recoder_.SetupPassThrough(GetUnicharset());
training_flags_ |= TF_COMPRESS_UNICHARSET;
}
network_->SetRandomizer(&randomizer_);
network_->CacheXScaleFactor(network_->XScaleFactor());
......@@ -145,91 +148,21 @@ bool LSTMRecognizer::LoadDictionary(const char* lang, TessdataManager* mgr) {
// ratings matrix and matching box_word for each WERD_RES in the output.
void LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
bool debug, double worst_dict_cert,
bool use_alternates,
const UNICHARSET* target_unicharset,
const TBOX& line_box, float score_ratio,
bool one_word,
const TBOX& line_box,
PointerVector<WERD_RES>* words) {
NetworkIO outputs;
float label_threshold = use_alternates ? 0.75f : 0.0f;
float scale_factor;
NetworkIO inputs;
if (!RecognizeLine(image_data, invert, debug, false, label_threshold,
&scale_factor, &inputs, &outputs))
if (!RecognizeLine(image_data, invert, debug, false, &scale_factor, &inputs,
&outputs))
return;
if (IsRecoding()) {
if (search_ == NULL) {
search_ =
new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
}
search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, NULL);
search_->ExtractBestPathAsWords(line_box, scale_factor, debug,
&GetUnicharset(), words);
} else {
GenericVector<int> label_coords;
GenericVector<int> labels;
LabelsFromOutputs(outputs, label_threshold, &labels, &label_coords);
WordsFromOutputs(outputs, labels, label_coords, line_box, debug,
use_alternates, one_word, score_ratio, scale_factor,
target_unicharset, words);
}
}
// Builds a set of tesseract-compatible WERD_RESs aligned to line_box,
// corresponding to the network output in outputs, labels, label_coords.
// one_word generates a single word output, that may include spaces inside.
// use_alternates generates alternative BLOB_CHOICEs and segmentation paths.
// If not NULL, we attempt to translate the output to target_unicharset, but do
// not guarantee success, due to mismatches. In that case the output words are
// marked with our UNICHARSET, not the caller's.
void LSTMRecognizer::WordsFromOutputs(
const NetworkIO& outputs, const GenericVector<int>& labels,
const GenericVector<int> label_coords, const TBOX& line_box, bool debug,
bool use_alternates, bool one_word, float score_ratio, float scale_factor,
const UNICHARSET* target_unicharset, PointerVector<WERD_RES>* words) {
// Convert labels to unichar-ids.
int word_end = 0;
float prev_space_cert = 0.0f;
for (int i = 0; i < labels.size(); i = word_end) {
word_end = i + 1;
if (labels[i] == null_char_ || labels[i] == UNICHAR_SPACE) {
continue;
}
float space_cert = 0.0f;
if (one_word) {
word_end = labels.size();
} else {
// Find the end of the word at the first null_char_ that leads to the
// first UNICHAR_SPACE.
while (word_end < labels.size() && labels[word_end] != UNICHAR_SPACE)
++word_end;
if (word_end < labels.size()) {
float rating;
outputs.ScoresOverRange(label_coords[word_end],
label_coords[word_end] + 1, UNICHAR_SPACE,
null_char_, &rating, &space_cert);
}
while (word_end > i && labels[word_end - 1] == null_char_) --word_end;
}
ASSERT_HOST(word_end > i);
// Create a WERD_RES for the output word.
if (debug)
tprintf("Creating word from outputs over [%d,%d)\n", i, word_end);
WERD_RES* word =
WordFromOutput(line_box, outputs, i, word_end, score_ratio,
MIN(prev_space_cert, space_cert), debug,
use_alternates && !SimpleTextOutput(), target_unicharset,
labels, label_coords, scale_factor);
if (word == NULL && target_unicharset != NULL) {
// Unicharset translation failed - use decoder_ instead, and disable
// the segmentation search on output, as it won't understand the encoding.
word = WordFromOutput(line_box, outputs, i, word_end, score_ratio,
MIN(prev_space_cert, space_cert), debug, false,
NULL, labels, label_coords, scale_factor);
}
prev_space_cert = space_cert;
words->push_back(word);
if (search_ == NULL) {
search_ =
new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
}
search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, NULL);
search_->ExtractBestPathAsWords(line_box, scale_factor, debug,
&GetUnicharset(), words);
}
// Helper computes min and mean best results in the output.
......@@ -251,12 +184,10 @@ void LSTMRecognizer::OutputStats(const NetworkIO& outputs, float* min_output,
// Recognizes the image_data, returning the labels,
// scores, and corresponding pairs of start, end x-coords in coords.
// If label_threshold is positive, uses it for making the labels, otherwise
// uses standard ctc.
bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
bool debug, bool re_invert,
float label_threshold, float* scale_factor,
NetworkIO* inputs, NetworkIO* outputs) {
float* scale_factor, NetworkIO* inputs,
NetworkIO* outputs) {
// Maximum width of image to train on.
const int kMaxImageWidth = 2560;
// This ensures consistent recognition results.
......@@ -312,135 +243,13 @@ bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
pixDestroy(&pix);
if (debug) {
GenericVector<int> labels, coords;
LabelsFromOutputs(*outputs, label_threshold, &labels, &coords);
LabelsFromOutputs(*outputs, &labels, &coords);
DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_);
DebugActivationPath(*outputs, labels, coords);
}
return true;
}
// Returns a tesseract-compatible WERD_RES from the line recognizer outputs.
// line_box should be the bounding box of the line image in the main image,
// outputs the output of the network,
// [word_start, word_end) the interval over which to convert,
// score_ratio for choosing alternate classifier choices,
// use_alternates to control generation of alternative segmentations,
// labels, label_coords, scale_factor from RecognizeLine above.
// If target_unicharset is not NULL, attempts to translate the internal
// unichar_ids to the target_unicharset, but falls back to untranslated ids
// if the translation should fail.
WERD_RES* LSTMRecognizer::WordFromOutput(
const TBOX& line_box, const NetworkIO& outputs, int word_start,
int word_end, float score_ratio, float space_certainty, bool debug,
bool use_alternates, const UNICHARSET* target_unicharset,
const GenericVector<int>& labels, const GenericVector<int>& label_coords,
float scale_factor) {
WERD_RES* word_res = InitializeWord(
line_box, word_start, word_end, space_certainty, use_alternates,
target_unicharset, labels, label_coords, scale_factor);
int max_blob_run = word_res->ratings->bandwidth();
for (int width = 1; width <= max_blob_run; ++width) {
int col = 0;
for (int i = word_start; i + width <= word_end; ++i) {
if (labels[i] != null_char_) {
// Starting at i, use width labels, but stop at the next null_char_.
// This forms all combinations of blobs between regions of null_char_.
int j = i + 1;
while (j - i < width && labels[j] != null_char_) ++j;
if (j - i == width) {
// Make the blob choices.
int end_coord = label_coords[j];
if (j < word_end && labels[j] == null_char_)
end_coord = label_coords[j + 1];
BLOB_CHOICE_LIST* choices = GetBlobChoices(
col, col + width - 1, debug, outputs, target_unicharset,
label_coords[i], end_coord, score_ratio);
if (choices == NULL) {
delete word_res;
return NULL;
}
word_res->ratings->put(col, col + width - 1, choices);
}
++col;
}
}
}
if (use_alternates) {
// Merge adjacent single results over null_char boundaries.
int col = 0;
for (int i = word_start; i + 2 < word_end; ++i) {
if (labels[i] != null_char_ && labels[i + 1] == null_char_ &&
labels[i + 2] != null_char_ &&
(i == word_start || labels[i - 1] == null_char_) &&
(i + 3 == word_end || labels[i + 3] == null_char_)) {
int end_coord = label_coords[i + 3];
if (i + 3 < word_end && labels[i + 3] == null_char_)
end_coord = label_coords[i + 4];
BLOB_CHOICE_LIST* choices =
GetBlobChoices(col, col + 1, debug, outputs, target_unicharset,
label_coords[i], end_coord, score_ratio);
if (choices == NULL) {
delete word_res;
return NULL;
}
word_res->ratings->put(col, col + 1, choices);
}
if (labels[i] != null_char_) ++col;
}
} else {
word_res->FakeWordFromRatings(TOP_CHOICE_PERM);
}
return word_res;
}
// Sets up a word with the ratings matrix and fake blobs with boxes in the
// right places.
WERD_RES* LSTMRecognizer::InitializeWord(const TBOX& line_box, int word_start,
int word_end, float space_certainty,
bool use_alternates,
const UNICHARSET* target_unicharset,
const GenericVector<int>& labels,
const GenericVector<int>& label_coords,
float scale_factor) {
// Make a fake blob for each non-zero label.
C_BLOB_LIST blobs;
C_BLOB_IT b_it(&blobs);
// num_blobs is the length of the diagonal of the ratings matrix.
int num_blobs = 0;
// max_blob_run is the diagonal width of the ratings matrix
int max_blob_run = 0;
int blob_run = 0;
for (int i = word_start; i < word_end; ++i) {
if (IsRecoding() && !recoder_.IsValidFirstCode(labels[i])) continue;
if (labels[i] != null_char_) {
// Make a fake blob.
TBOX box(label_coords[i], 0, label_coords[i + 1], line_box.height());
box.scale(scale_factor);
box.move(ICOORD(line_box.left(), line_box.bottom()));
box.set_top(line_box.top());
b_it.add_after_then_move(C_BLOB::FakeBlob(box));
++num_blobs;
++blob_run;
}
if (labels[i] == null_char_ || i + 1 == word_end) {
if (blob_run > max_blob_run)
max_blob_run = blob_run;
}
}
if (!use_alternates) max_blob_run = 1;
ASSERT_HOST(label_coords.size() >= word_end);
// Make a fake word from the blobs.
WERD* word = new WERD(&blobs, word_start > 1 ? 1 : 0, NULL);
// Make a WERD_RES from the word.
WERD_RES* word_res = new WERD_RES(word);
word_res->uch_set =
target_unicharset != NULL ? target_unicharset : &GetUnicharset();
word_res->combination = true; // Give it ownership of the word.
word_res->space_certainty = space_certainty;
word_res->ratings = new MATRIX(num_blobs, max_blob_run);
return word_res;
}
// Converts an array of labels to utf-8, whether or not the labels are
// augmented with character boundaries.
STRING LSTMRecognizer::DecodeLabels(const GenericVector<int>& labels) {
......@@ -569,83 +378,14 @@ static bool NullIsBest(const NetworkIO& output, float null_thr,
// and start xcoords of each char, and each null_char_, with an additional
// final xcoord for the end of the output.
// The conversion method is determined by internal state.
void LSTMRecognizer::LabelsFromOutputs(const NetworkIO& outputs, float null_thr,
void LSTMRecognizer::LabelsFromOutputs(const NetworkIO& outputs,
GenericVector<int>* labels,
GenericVector<int>* xcoords) {
if (SimpleTextOutput()) {
LabelsViaSimpleText(outputs, labels, xcoords);
} else if (IsRecoding()) {
LabelsViaReEncode(outputs, labels, xcoords);
} else if (null_thr <= 0.0) {
LabelsViaCTC(outputs, labels, xcoords);
} else {
LabelsViaThreshold(outputs, null_thr, labels, xcoords);
}
}
// Converts the network output to a sequence of labels, using a threshold
// on the null_char_ to determine character boundaries. Outputs labels, scores
// and start xcoords of each char, and each null_char_, with an additional
// final xcoord for the end of the output.
// The label output is the one with the highest score in the interval between
// null_chars_.
void LSTMRecognizer::LabelsViaThreshold(const NetworkIO& output,
float null_thr,
GenericVector<int>* labels,
GenericVector<int>* xcoords) {
labels->truncate(0);
xcoords->truncate(0);
int width = output.Width();
int t = 0;
// Skip any initial non-char.
while (t < width && NullIsBest(output, null_thr, null_char_, t)) {
++t;
}
while (t < width) {
ASSERT_HOST(!std::isnan(output.f(t)[null_char_]));
int label = output.BestLabel(t, null_char_, null_char_, NULL);
int char_start = t++;
while (t < width && !NullIsBest(output, null_thr, null_char_, t) &&
label == output.BestLabel(t, null_char_, null_char_, NULL)) {
++t;
}
int char_end = t;
labels->push_back(label);
xcoords->push_back(char_start);
// Find the end of the non-char, and compute its score.
while (t < width && NullIsBest(output, null_thr, null_char_, t)) {
++t;
}
if (t > char_end) {
labels->push_back(null_char_);
xcoords->push_back(char_end);
}
}
xcoords->push_back(width);
}
// Converts the network output to a sequence of labels, with scores and
// start x-coords of the character labels. Retains the null_char_ as the
// end x-coord, where already present, otherwise the start of the next
// character is the end.
// The number of labels, scores, and xcoords is always matched, except that
// there is always an additional xcoord for the last end position.
void LSTMRecognizer::LabelsViaCTC(const NetworkIO& output,
GenericVector<int>* labels,
GenericVector<int>* xcoords) {
labels->truncate(0);
xcoords->truncate(0);
int width = output.Width();
int t = 0;
while (t < width) {
float score = 0.0f;
int label = output.BestLabel(t, &score);
labels->push_back(label);
xcoords->push_back(t);
while (++t < width && output.BestLabel(t, NULL) == label) {
}
LabelsViaReEncode(outputs, labels, xcoords);
}
xcoords->push_back(width);
}
// As LabelsViaCTC except that this function constructs the best path that
......@@ -681,82 +421,6 @@ void LSTMRecognizer::LabelsViaSimpleText(const NetworkIO& output,
xcoords->push_back(width);
}
// Helper returns a BLOB_CHOICE_LIST for the choices in a given x-range.
// Handles either LSTM labels or direct unichar-ids.
// Score ratio determines the worst ratio between top choice and remainder.
// If target_unicharset is not NULL, attempts to translate to the target
// unicharset, returning NULL on failure.
BLOB_CHOICE_LIST* LSTMRecognizer::GetBlobChoices(
int col, int row, bool debug, const NetworkIO& output,
const UNICHARSET* target_unicharset, int x_start, int x_end,
float score_ratio) {
float rating = 0.0f, certainty = 0.0f;
int label = output.BestChoiceOverRange(x_start, x_end, UNICHAR_SPACE,
null_char_, &rating, &certainty);
int unichar_id = label == null_char_ ? UNICHAR_SPACE : label;
if (debug) {
tprintf("Best choice over range %d,%d=unichar%d=%s r = %g, cert=%g\n",
x_start, x_end, unichar_id, DecodeSingleLabel(label), rating,
certainty);
}
BLOB_CHOICE_LIST* choices = new BLOB_CHOICE_LIST;
BLOB_CHOICE_IT bc_it(choices);
if (!AddBlobChoices(unichar_id, rating, certainty, col, row,
target_unicharset, &bc_it)) {
delete choices;
return NULL;
}
// Get the other choices.
double best_cert = certainty;
for (int c = 0; c < output.NumFeatures(); ++c) {
if (c == label || c == UNICHAR_SPACE || c == null_char_) continue;
// Compute the score over the range.
output.ScoresOverRange(x_start, x_end, c, null_char_, &rating, &certainty);
int unichar_id = c == null_char_ ? UNICHAR_SPACE : c;
if (certainty >= best_cert - score_ratio &&
!AddBlobChoices(unichar_id, rating, certainty, col, row,
target_unicharset, &bc_it)) {
delete choices;
return NULL;
}
}
choices->sort(&BLOB_CHOICE::SortByRating);
if (bc_it.length() > kMaxChoices) {
bc_it.move_to_first();
for (int i = 0; i < kMaxChoices; ++i)
bc_it.forward();
while (!bc_it.at_first()) {
delete bc_it.extract();
bc_it.forward();
}
}
return choices;
}
// Adds to the given iterator, the blob choices for the target_unicharset
// that correspond to the given LSTM unichar_id.
// Returns false if unicharset translation failed.
bool LSTMRecognizer::AddBlobChoices(int unichar_id, float rating,
float certainty, int col, int row,
const UNICHARSET* target_unicharset,
BLOB_CHOICE_IT* bc_it) {
int target_id = unichar_id;
if (target_unicharset != NULL) {
const char* utf8 = GetUnicharset().id_to_unichar(unichar_id);
if (target_unicharset->contains_unichar(utf8)) {
target_id = target_unicharset->unichar_to_id(utf8);
} else {
return false;
}
}
BLOB_CHOICE* choice = new BLOB_CHOICE(target_id, rating, certainty, -1, 1.0f,
static_cast<float>(MAX_INT16), 0.0f,
BCC_STATIC_CLASSIFIER);
choice->set_matrix_cell(col, row);
bc_it->add_after_then_move(choice);
return true;
}
// Returns a string corresponding to the label starting at start. Sets *end
// to the next start and if non-null, *decoded to the unichar id.
const char* LSTMRecognizer::DecodeLabel(const GenericVector<int>& labels,
......
......@@ -169,82 +169,30 @@ class LSTMRecognizer {
bool LoadDictionary(const char* lang, TessdataManager* mgr);
// Recognizes the line image, contained within image_data, returning the
// ratings matrix and matching box_word for each WERD_RES in the output.
// recognized tesseract WERD_RES for the words.
// If invert, tries inverted as well if the normal interpretation doesn't
// produce a good enough result. If use_alternates, the ratings matrix is
// filled with segmentation and classifier alternatives that may be searched
// using the standard beam search, otherwise, just a diagonal and prebuilt
// best_choice. The line_box is used for computing the box_word in the
// output words. Score_ratio is used to determine the classifier alternates.
// If one_word, then a single WERD_RES is formed, regardless of the spaces
// found during recognition.
// If not NULL, we attempt to translate the output to target_unicharset, but
// do not guarantee success, due to mismatches. In that case the output words
// are marked with our UNICHARSET, not the caller's.
// produce a good enough result. The line_box is used for computing the
// box_word in the output words. worst_dict_cert is the worst certainty that
// will be used in a dictionary word.
void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
double worst_dict_cert, bool use_alternates,
const UNICHARSET* target_unicharset, const TBOX& line_box,
float score_ratio, bool one_word,
double worst_dict_cert, const TBOX& line_box,
PointerVector<WERD_RES>* words);
// Builds a set of tesseract-compatible WERD_RESs aligned to line_box,
// corresponding to the network output in outputs, labels, label_coords.
// one_word generates a single word output, that may include spaces inside.
// use_alternates generates alternative BLOB_CHOICEs and segmentation paths,
// with cut-offs determined by scale_factor.
// If not NULL, we attempt to translate the output to target_unicharset, but
// do not guarantee success, due to mismatches. In that case the output words
// are marked with our UNICHARSET, not the caller's.
void WordsFromOutputs(const NetworkIO& outputs,
const GenericVector<int>& labels,
const GenericVector<int> label_coords,
const TBOX& line_box, bool debug, bool use_alternates,
bool one_word, float score_ratio, float scale_factor,
const UNICHARSET* target_unicharset,
PointerVector<WERD_RES>* words);
// Helper computes min and mean best results in the output.
void OutputStats(const NetworkIO& outputs,
float* min_output, float* mean_output, float* sd);
// Recognizes the image_data, returning the labels,
// scores, and corresponding pairs of start, end x-coords in coords.
// If label_threshold is positive, uses it for making the labels, otherwise
// uses standard ctc. Returned in scale_factor is the reduction factor
// Returned in scale_factor is the reduction factor
// between the image and the output coords, for computing bounding boxes.
// If re_invert is true, the input is inverted back to its original
// photometric interpretation if inversion is attempted but fails to
// improve the results. This ensures that outputs contains the correct
// forward outputs for the best photometric interpretation.
// inputs is filled with the used inputs to the network, and if not null,
// target boxes is filled with scaled truth boxes if present in image_data.
// inputs is filled with the used inputs to the network.
bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
bool re_invert, float label_threshold, float* scale_factor,
NetworkIO* inputs, NetworkIO* outputs);
// Returns a tesseract-compatible WERD_RES from the line recognizer outputs.
// line_box should be the bounding box of the line image in the main image,
// outputs the output of the network,
// [word_start, word_end) the interval over which to convert,
// score_ratio for choosing alternate classifier choices,
// use_alternates to control generation of alternative segmentations,
// labels, label_coords, scale_factor from RecognizeLine above.
// If target_unicharset is not NULL, attempts to translate the internal
// unichar_ids to the target_unicharset, but falls back to untranslated ids
// if the translation should fail.
WERD_RES* WordFromOutput(const TBOX& line_box, const NetworkIO& outputs,
int word_start, int word_end, float score_ratio,
float space_certainty, bool debug,
bool use_alternates,
const UNICHARSET* target_unicharset,
const GenericVector<int>& labels,
const GenericVector<int>& label_coords,
float scale_factor);
// Sets up a word with the ratings matrix and fake blobs with boxes in the
// right places.
WERD_RES* InitializeWord(const TBOX& line_box, int word_start, int word_end,
float space_certainty, bool use_alternates,
const UNICHARSET* target_unicharset,
const GenericVector<int>& labels,
const GenericVector<int>& label_coords,
float scale_factor);
bool re_invert, float* scale_factor, NetworkIO* inputs,
NetworkIO* outputs);
// Converts an array of labels to utf-8, whether or not the labels are
// augmented with character boundaries.
......@@ -287,28 +235,8 @@ class LSTMRecognizer {
// and start xcoords of each char, and each null_char_, with an additional
// final xcoord for the end of the output.
// The conversion method is determined by internal state.
void LabelsFromOutputs(const NetworkIO& outputs, float null_thr,
GenericVector<int>* labels,
void LabelsFromOutputs(const NetworkIO& outputs, GenericVector<int>* labels,
GenericVector<int>* xcoords);
// Converts the network output to a sequence of labels, using a threshold
// on the null_char_ to determine character boundaries. Outputs labels, scores
// and start xcoords of each char, and each null_char_, with an additional
// final xcoord for the end of the output.
// The label output is the one with the highest score in the interval between
// null_chars_.
void LabelsViaThreshold(const NetworkIO& output,
float null_threshold,
GenericVector<int>* labels,
GenericVector<int>* xcoords);
// Converts the network output to a sequence of labels, with scores and
// start x-coords of the character labels. Retains the null_char_ character as
// the end x-coord, where already present, otherwise the start of the next
// character is the end.
// The number of labels, scores, and xcoords is always matched, except that
// there is always an additional xcoord for the last end position.
void LabelsViaCTC(const NetworkIO& output,
GenericVector<int>* labels,
GenericVector<int>* xcoords);
// As LabelsViaCTC except that this function constructs the best path that
// contains only legal sequences of subcodes for recoder_.
void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
......@@ -320,23 +248,6 @@ class LSTMRecognizer {
GenericVector<int>* labels,
GenericVector<int>* xcoords);
// Helper returns a BLOB_CHOICE_LIST for the choices in a given x-range.
// Handles either LSTM labels or direct unichar-ids.
// Score ratio determines the worst ratio between top choice and remainder.
// If target_unicharset is not NULL, attempts to translate to the target
// unicharset, returning NULL on failure.
BLOB_CHOICE_LIST* GetBlobChoices(int col, int row, bool debug,
const NetworkIO& output,
const UNICHARSET* target_unicharset,
int x_start, int x_end, float score_ratio);
// Adds to the given iterator, the blob choices for the target_unicharset
// that correspond to the given LSTM unichar_id.
// Returns false if unicharset translation failed.
bool AddBlobChoices(int unichar_id, float rating, float certainty, int col,
int row, const UNICHARSET* target_unicharset,
BLOB_CHOICE_IT* bc_it);
// Returns a string corresponding to the label starting at start. Sets *end
// to the next start and if non-null, *decoded to the unichar id.
const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
......
......@@ -180,13 +180,9 @@ bool LSTMTrainer::InitNetwork(const STRING& network_spec, int append_index,
weight_range_ = weight_range;
learning_rate_ = learning_rate;
momentum_ = momentum;
int num_outputs = null_char_ == GetUnicharset().size()
? null_char_ + 1
: GetUnicharset().size();
if (IsRecoding()) num_outputs = recoder_.code_range();
if (!NetworkBuilder::InitNetwork(num_outputs, network_spec, append_index,
net_flags, weight_range, &randomizer_,
&network_)) {
if (!NetworkBuilder::InitNetwork(recoder_.code_range(), network_spec,
append_index, net_flags, weight_range,
&randomizer_, &network_)) {
return false;
}
network_str_ += network_spec;
......@@ -852,7 +848,7 @@ Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata,
float image_scale;
NetworkIO inputs;
bool invert = trainingdata->boxes().empty();
if (!RecognizeLine(*trainingdata, invert, debug, invert, 0.0f, &image_scale,
if (!RecognizeLine(*trainingdata, invert, debug, invert, &image_scale,
&inputs, fwd_outputs)) {
tprintf("Image not trainable\n");
return UNENCODABLE;
......@@ -875,10 +871,10 @@ Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata,
}
GenericVector<int> ocr_labels;
GenericVector<int> xcoords;
LabelsFromOutputs(*fwd_outputs, 0.0f, &ocr_labels, &xcoords);
LabelsFromOutputs(*fwd_outputs, &ocr_labels, &xcoords);
// CTC does not produce correct target labels to begin with.
if (loss_type != LT_CTC) {
LabelsFromOutputs(*targets, 0.0f, &truth_labels, &xcoords);
LabelsFromOutputs(*targets, &truth_labels, &xcoords);
}
if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels,
*targets)) {
......@@ -1021,8 +1017,9 @@ void LSTMTrainer::SetUnicharsetProperties(const STRING& script_dir) {
tprintf("Failed to load radical-stroke info from: %s\n",
filename.string());
}
training_flags_ &= ~TF_COMPRESS_UNICHARSET;
}
training_flags_ |= TF_COMPRESS_UNICHARSET;
recoder_.SetupPassThrough(GetUnicharset());
}
// Outputs the string and periodically displays the given network inputs
......@@ -1043,7 +1040,7 @@ bool LSTMTrainer::DebugLSTMTraining(const NetworkIO& inputs,
// Get class labels, xcoords and string.
GenericVector<int> labels;
GenericVector<int> xcoords;
LabelsFromOutputs(outputs, 0.0f, &labels, &xcoords);
LabelsFromOutputs(outputs, &labels, &xcoords);
STRING text = DecodeLabels(labels);
tprintf("Iteration %d: ALIGNED TRUTH : %s\n",
training_iteration(), text.string());
......
......@@ -125,25 +125,6 @@ void Wordrec::SegSearch(WERD_RES* word_res,
}
}
// Setup and run just the initial segsearch on an established matrix,
// without doing any additional chopping or joining.
void Wordrec::WordSearch(WERD_RES* word_res) {
LMPainPoints pain_points(segsearch_max_pain_points,
segsearch_max_char_wh_ratio,
assume_fixed_pitch_char_segment,
&getDict(), segsearch_debug_level);
GenericVector<SegSearchPending> pending;
BestChoiceBundle best_choice_bundle(word_res->ratings->dimension());
// Run Segmentation Search.
InitialSegSearch(word_res, &pain_points, &pending, &best_choice_bundle, NULL);
if (segsearch_debug_level > 0) {
tprintf("Ending ratings matrix%s:\n",
wordrec_enable_assoc ? " (with assoc)" : "");
word_res->ratings->print(getDict().getUnicharset());
}
}
// Setup and run just the initial segsearch on an established matrix,
// without doing any additional chopping or joining.
// (Internal factored version that can be used as part of the main SegSearch.)
......
......@@ -270,9 +270,6 @@ class Wordrec : public Classify {
void SegSearch(WERD_RES* word_res,
BestChoiceBundle* best_choice_bundle,
BlamerBundle* blamer_bundle);
// Setup and run just the initial segsearch on an established matrix,
// without doing any additional chopping or joining.
void WordSearch(WERD_RES* word_res);
// Setup and run just the initial segsearch on an established matrix,
// without doing any additional chopping or joining.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册