///////////////////////////////////////////////////////////////////////
// File:        lstmrecognizer.h
// Description: Top-level line recognizer class for LSTM-based networks.
// Author:      Ray Smith
// Created:     Thu May 02 08:57:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
#define TESSERACT_LSTM_LSTMRECOGNIZER_H_

#include "ccutil.h"
#include "helpers.h"
#include "imagedata.h"
#include "matrix.h"
#include "network.h"
#include "networkscratch.h"
#include "recodebeam.h"
#include "series.h"
#include "strngs.h"
#include "unicharcompress.h"

class BLOB_CHOICE_IT;
struct Pix;
class ROW_RES;
class ScrollView;
class TBOX;
class WERD_RES;

namespace tesseract {

class Dict;
class ImageData;

// Bit flags that control the training mode. Each enumerator is a distinct
// bit, so values can be combined with bitwise OR and tested with bitwise AND.
enum TrainingFlags {
  TF_INT_MODE = 1 << 0,              // Tested by IsIntMode().
  TF_AUTO_HARDEN = 1 << 1,           // Tested by IsHardening().
  TF_ROUND_ROBIN_TRAINING = 1 << 4,  // Tested by CacheStrategy().
  TF_COMPRESS_UNICHARSET = 1 << 6,   // Tested by IsRecoding().
};

// Top-level line recognizer class for LSTM-based networks.
// Note that a sub-class, LSTMTrainer is used for training.
class LSTMRecognizer {
 public:
  // Default constructor; declared here and defined out of line.
  LSTMRecognizer();
  // Destructor; declared here and defined out of line.
  // NOTE(review): it is non-virtual although the class comment says that
  // LSTMTrainer sub-classes this type -- confirm that no instance is ever
  // deleted through an LSTMRecognizer pointer.
  ~LSTMRecognizer();

  // Number of outputs as reported by the network. network_ is dereferenced
  // without a null check, so a network must be present before calling.
  int NumOutputs() const { return network_->NumOutputs(); }
  // Accessor for training_iteration_.
  int training_iteration() const { return training_iteration_; }
  // Accessor for sample_iteration_.
  int sample_iteration() const { return sample_iteration_; }
  // Accessor for learning_rate_. The member is stored as a float and is
  // widened to double on return.
  double learning_rate() const { return learning_rate_; }
  // True if the TF_AUTO_HARDEN bit is set in training_flags_.
  bool IsHardening() const {
    const bool auto_harden = (training_flags_ & TF_AUTO_HARDEN) != 0;
    return auto_harden;
  }
  // Returns the loss type of the network's output shape, or LT_NONE if there
  // is no network.
  LossType OutputLossType() const {
    if (network_ != nullptr) {
      // A default-constructed shape is passed in and then overwritten with
      // the output shape computed by the network.
      StaticShape shape;
      shape = network_->OutputShape(shape);
      return shape.loss_type();
    }
    return LT_NONE;
  }
  bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; }
  // True if the TF_INT_MODE bit is set in training_flags_.
  bool IsIntMode() const {
    const bool int_mode = (training_flags_ & TF_INT_MODE) != 0;
    return int_mode;
  }
  // True if the TF_COMPRESS_UNICHARSET bit is set in training_flags_, which
  // means recoder_ is active to re-encode text to a smaller space.
  bool IsRecoding() const {
    const bool compressing = (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
    return compressing;
  }
  // Returns the cache strategy for the DocumentCache: CS_ROUND_ROBIN if the
  // TF_ROUND_ROBIN_TRAINING bit is set in training_flags_, otherwise
  // CS_SEQUENTIAL.
  CachingStrategy CacheStrategy() const {
    if ((training_flags_ & TF_ROUND_ROBIN_TRAINING) != 0) {
      return CS_ROUND_ROBIN;
    }
    return CS_SEQUENTIAL;
  }
  // Returns true if the network's type is NT_TENSORFLOW. network_ is
  // dereferenced without a null check.
  bool IsTensorFlow() const {
    const bool is_tensorflow = network_->type() == NT_TENSORFLOW;
    return is_tensorflow;
  }
  // Returns a vector of layer ids that can be passed to other layer functions
  // to access a specific layer.
  // Asserts that a network is present and that its type is NT_SERIES.
  GenericVector<STRING> EnumerateLayers() const {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    // The assertion above has established the dynamic type, so a static_cast
    // down the class hierarchy is the correct cast; reinterpret_cast performs
    // no pointer adjustment and is not meant for related class types.
    Series* series = static_cast<Series*>(network_);
    GenericVector<STRING> layers;
    series->EnumerateLayers(nullptr, &layers);
    return layers;
  }
  // Returns a specific layer from its id (from EnumerateLayers).
  // Asserts that the network is a Series, and that id is longer than one
  // character and begins with ':'. The leading ':' is skipped before the id
  // is handed to the Series.
  Network* GetLayer(const STRING& id) const {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    ASSERT_HOST(id.length() > 1 && id[0] == ':');
    // static_cast is the correct cast for a downcast within a class
    // hierarchy whose dynamic type has just been asserted.
    Series* series = static_cast<Series*>(network_);
    return series->GetLayer(&id[1]);
  }
  // Returns the learning rate of the layer from its id.
  // If the network does not have the NF_LAYER_SPECIFIC_LR flag, the global
  // learning_rate_ is returned and id is not examined. Otherwise id must be
  // longer than one character and begin with ':' (asserted).
  float GetLayerLearningRate(const STRING& id) const {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    if (!network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
      return learning_rate_;
    }
    ASSERT_HOST(id.length() > 1 && id[0] == ':');
    // static_cast is the correct cast for a downcast within a class
    // hierarchy whose dynamic type has just been asserted.
    Series* series = static_cast<Series*>(network_);
    return series->LayerLearningRate(&id[1]);
  }
  // Multiplies the all the learning rate(s) by the given factor.
  void ScaleLearningRate(double factor) {
    ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
    learning_rate_ *= factor;
    if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
      GenericVector<STRING> layers = EnumerateLayers();
      for (int i = 0; i < layers.size(); ++i) {
        ScaleLayerLearningRate(layers[i], factor);
      }
    }
  }
  // Multiplies the learning rate of the layer with id, by the given factor.
  // Asserts that the network is a Series, and that id is longer than one
  // character and begins with ':'. The leading ':' is skipped before the id
  // is handed to the Series.
  void ScaleLayerLearningRate(const STRING& id, double factor) {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    ASSERT_HOST(id.length() > 1 && id[0] == ':');
    // static_cast is the correct cast for a downcast within a class
    // hierarchy whose dynamic type has just been asserted.
    Series* series = static_cast<Series*>(network_);
    series->ScaleLayerLearningRate(&id[1], factor);
  }

  // True if the network has the NF_ADA_GRAD flag set, i.e. it is using
  // adagrad to train. network_ is dereferenced without a null check.
  bool IsUsingAdaGrad() const {
    const bool ada_grad = network_->TestFlag(NF_ADA_GRAD);
    return ada_grad;
  }
  // Read-only access to the UNICHARSET that this classifier works with; it
  // is the unicharset member of ccutil_.
  const UNICHARSET& GetUnicharset() const {
    return ccutil_.unicharset;
  }
  // Provides access to the Dict that this classifier works with.
  const Dict* GetDict() const {
    // dict_ is described as an optional language model.
    // NOTE(review): callers should presumably expect a null result -- confirm.
    return dict_;
  }
  // Sets the sample iteration to the given value. The sample_iteration_
  // determines the seed for the random number generator. The training
  // iteration is incremented only by a successful training iteration.
  void SetIteration(int iteration) { sample_iteration_ = iteration; }
  // Accessors for textline image normalization.
  // Number of inputs as reported by the network. network_ is dereferenced
  // without a null check.
  int NumInputs() const { return network_->NumInputs(); }
  // Accessor for null_char_.
  int null_char() const {
    return null_char_;
  }

  // Writes to the given file. Returns false in case of error.
  // fp is the destination file. The member comments in this class record
  // which members are serialized.
  bool Serialize(TFile* fp) const;
  // Reads from the given file. Returns false in case of error.
  // If swap is true, assumes a big/little-endian swap is needed.
  // fp is the source file.
  bool DeSerialize(bool swap, TFile* fp);
  // Loads the dictionary if possible from the traineddata file.
  // Prints a warning message, and returns false but otherwise fails silently
  // and continues to work without it if loading fails.
  // Note that dictionary load is independent from DeSerialize, but dependent
  // on the unicharset matching. This enables training to deserialize a model
  // from checkpoint or restore without having to go back and reload the
  // dictionary.
  // NOTE(review): data_file_name is presumably the path of the traineddata
  // file and lang the language to load from it -- confirm against the
  // implementation.
  bool LoadDictionary(const char* data_file_name, const char* lang);

  // Recognizes the line image, contained within image_data, returning the
  // ratings matrix and matching box_word for each WERD_RES in the output.
  // If invert, tries inverted as well if the normal interpretation doesn't
  // produce a good enough result. If use_alternates, the ratings matrix is
  // filled with segmentation and classifier alternatives that may be searched
  // using the standard beam search, otherwise, just a diagonal and prebuilt
  // best_choice. The line_box is used for computing the box_word in the
  // output words. Score_ratio is used to determine the classifier alternates.
  // If one_word, then a single WERD_RES is formed, regardless of the spaces
  // found during recognition.
  // If not NULL, we attempt to translate the output to target_unicharset, but
  // do not guarantee success, due to mismatches. In that case the output words
  // are marked with our UNICHARSET, not the caller's.
  // The resulting WERD_RESs are delivered through words.
  // NOTE(review): debug and worst_dict_cert are not described above;
  // presumably debug enables diagnostic output and worst_dict_cert is a
  // certainty limit used with the dictionary -- confirm against the
  // implementation.
  void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
                     double worst_dict_cert, bool use_alternates,
                     const UNICHARSET* target_unicharset, const TBOX& line_box,
                     float score_ratio, bool one_word,
                     PointerVector<WERD_RES>* words);
  // Builds a set of tesseract-compatible WERD_RESs aligned to line_box,
  // corresponding to the network output in outputs, labels, label_coords.
  // one_word generates a single word output, that may include spaces inside.
  // use_alternates generates alternative BLOB_CHOICEs and segmentation paths,
  // with cut-offs determined by scale_factor.
  // If not NULL, we attempt to translate the output to target_unicharset, but
  // do not guarantee success, due to mismatches. In that case the output words
  // are marked with our UNICHARSET, not the caller's.
  // The resulting WERD_RESs are delivered through words.
  // NOTE(review): label_coords is passed by value (no '&'), unlike labels, so
  // the vector is copied on every call. Changing it to a const reference
  // requires the out-of-line definition to change in step.
  // NOTE(review): the cut-offs are attributed to scale_factor here, whereas
  // the RecognizeLine comment attributes classifier alternates to
  // score_ratio -- confirm which is intended.
  void WordsFromOutputs(const NetworkIO& outputs,
                        const GenericVector<int>& labels,
                        const GenericVector<int> label_coords,
                        const TBOX& line_box, bool debug, bool use_alternates,
                        bool one_word, float score_ratio, float scale_factor,
                        const UNICHARSET* target_unicharset,
                        PointerVector<WERD_RES>* words);

  // Helper computes min and mean best results in the output.
  // Results are written through min_output and mean_output.
  // NOTE(review): sd is presumably the standard deviation of the same
  // values -- confirm against the implementation.
  void OutputStats(const NetworkIO& outputs,
                   float* min_output, float* mean_output, float* sd);
  // Recognizes the image_data, returning the labels,
  // scores, and corresponding pairs of start, end x-coords in coords.
  // If label_threshold is positive, uses it for making the labels, otherwise
  // uses standard ctc. Returned in scale_factor is the reduction factor
  // between the image and the output coords, for computing bounding boxes.
  // If re_invert is true, the input is inverted back to its original
  // photometric interpretation if inversion is attempted but fails to
  // improve the results. This ensures that outputs contains the correct
  // forward outputs for the best photometric interpretation.
  // inputs is filled with the used inputs to the network, and if not null,
  // target boxes is filled with scaled truth boxes if present in image_data.
  // NOTE(review): the preceding comment mentions labels, coords and target
  // boxes, but this signature has no such parameters; results are delivered
  // only through scale_factor, inputs and outputs. The meaning of the bool
  // return value is not stated -- presumably false on failure; confirm
  // against the implementation.
  bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
                     bool re_invert, float label_threshold, float* scale_factor,
                     NetworkIO* inputs, NetworkIO* outputs);
  // Returns a tesseract-compatible WERD_RES from the line recognizer outputs.
  // line_box should be the bounding box of the line image in the main image,
  // outputs the output of the network,
  // [word_start, word_end) the interval over which to convert,
  // score_ratio for choosing alternate classifier choices,
  // use_alternates to control generation of alternative segmentations,
  // labels, label_coords, scale_factor from RecognizeLine above.
  // If target_unicharset is not NULL, attempts to translate the internal
  // unichar_ids to the target_unicharset, but falls back to untranslated ids
  // if the translation should fail.
  // NOTE(review): space_certainty and debug are not described above, and the
  // ownership of the returned WERD_RES is not stated -- confirm against the
  // implementation.
  WERD_RES* WordFromOutput(const TBOX& line_box, const NetworkIO& outputs,
                           int word_start, int word_end, float score_ratio,
                           float space_certainty, bool debug,
                           bool use_alternates,
                           const UNICHARSET* target_unicharset,
                           const GenericVector<int>& labels,
                           const GenericVector<int>& label_coords,
                           float scale_factor);
  // Sets up a word with the ratings matrix and fake blobs with boxes in the
  // right places.
  // Takes the same line_box, word_start, word_end, space_certainty,
  // use_alternates, target_unicharset, labels, label_coords and scale_factor
  // parameters as WordFromOutput.
  // NOTE(review): the ownership of the returned WERD_RES is not stated --
  // confirm against the implementation.
  WERD_RES* InitializeWord(const TBOX& line_box, int word_start, int word_end,
                           float space_certainty, bool use_alternates,
                           const UNICHARSET* target_unicharset,
                           const GenericVector<int>& labels,
                           const GenericVector<int>& label_coords,
                           float scale_factor);

  // Converts an array of labels to utf-8, whether or not the labels are
  // augmented with character boundaries.
  // Returns the decoded text as a STRING. Declared non-const.
  STRING DecodeLabels(const GenericVector<int>& labels);

  // Displays the forward results in a window with the characters and
  // boundaries as determined by the labels and label_coords.
  // NOTE(review): window is a pointer to a ScrollView pointer, so presumably
  // the window is created on first use, titled with window_name, and handed
  // back through *window -- confirm against the implementation.
  void DisplayForward(const NetworkIO& inputs,
                      const GenericVector<int>& labels,
                      const GenericVector<int>& label_coords,
                      const char* window_name,
                      ScrollView** window);

 protected:
  // Sets the random seed from the sample_iteration_;
  void SetRandomSeed() {
    inT64 seed = static_cast<inT64>(sample_iteration_) * 0x10000001;
    randomizer_.set_seed(seed);
    randomizer_.IntRand();
  }

  // Displays the labels and cuts at the corresponding xcoords.
  // Size of labels should match xcoords.
  // Drawing is done in the given window.
  // NOTE(review): height is presumably the height of the displayed image --
  // confirm against the implementation.
  void DisplayLSTMOutput(const GenericVector<int>& labels,
                         const GenericVector<int>& xcoords,
                         int height, ScrollView* window);

  // Prints debug output detailing the activation path that is implied by the
  // xcoords.
  // outputs is the network output; labels and xcoords are the label sequence
  // and the x-coordinates that go with it.
  void DebugActivationPath(const NetworkIO& outputs,
                           const GenericVector<int>& labels,
                           const GenericVector<int>& xcoords);

  // Prints debug output detailing activations and 2nd choice over a range
  // of positions.
  // NOTE(review): the range is presumably given by x_start and x_end, with
  // label the text to print for best_choice; whether x_end is inclusive is
  // not stated -- confirm against the implementation.
  void DebugActivationRange(const NetworkIO& outputs, const char* label,
                            int best_choice, int x_start, int x_end);

  // Converts the network output to a sequence of labels. Outputs labels, scores
  // and start xcoords of each char, and each null_char_, with an additional
  // final xcoord for the end of the output.
  // The conversion method is determined by internal state.
  // NOTE(review): scores are mentioned above but this signature has no
  // scores parameter; only labels and xcoords are written. null_thr is
  // presumably the null-character threshold passed on to LabelsViaThreshold
  // -- confirm against the implementation.
  void LabelsFromOutputs(const NetworkIO& outputs, float null_thr,
                         GenericVector<int>* labels,
                         GenericVector<int>* xcoords);
  // Converts the network output to a sequence of labels, using a threshold
  // on the null_char_ to determine character boundaries. Outputs labels, scores
  // and start xcoords of each char, and each null_char_, with an additional
  // final xcoord for the end of the output.
  // The label output is the one with the highest score in the interval between
  // null_chars_.
  // NOTE(review): scores are mentioned above but this signature has no
  // scores parameter; only labels and xcoords are written.
  void LabelsViaThreshold(const NetworkIO& output,
                          float null_threshold,
                          GenericVector<int>* labels,
                          GenericVector<int>* xcoords);
  // Converts the network output to a sequence of labels, with scores and
  // start x-coords of the character labels. Retains the null_char_ character as
  // the end x-coord, where already present, otherwise the start of the next
  // character is the end.
  // The number of labels, scores, and xcoords is always matched, except that
  // there is always an additional xcoord for the last end position.
  // NOTE(review): scores are mentioned above but this signature has no
  // scores parameter; only labels and xcoords are written.
  void LabelsViaCTC(const NetworkIO& output,
                    GenericVector<int>* labels,
                    GenericVector<int>* xcoords);
  // As LabelsViaCTC except that this function constructs the best path that
  // contains only legal sequences of subcodes for recoder_.
  // Results are written through labels and xcoords.
  void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
                         GenericVector<int>* xcoords);
  // Converts the network output to a sequence of labels, with scores, using
  // the simple character model (each position is a char, and the null_char_ is
  // mainly intended for tail padding.)
  // NOTE(review): scores are mentioned above but this signature has no
  // scores parameter; only labels and xcoords are written.
  void LabelsViaSimpleText(const NetworkIO& output,
                           GenericVector<int>* labels,
                           GenericVector<int>* xcoords);

  // Helper returns a BLOB_CHOICE_LIST for the choices in a given x-range.
  // Handles either LSTM labels or direct unichar-ids.
  // Score ratio determines the worst ratio between top choice and remainder.
  // If target_unicharset is not NULL, attempts to translate to the target
  // unicharset, returning NULL on failure.
  // NOTE(review): the x-range is presumably given by x_start and x_end, and
  // col/row presumably locate the result in the ratings matrix; ownership of
  // the returned list is not stated -- confirm against the implementation.
  BLOB_CHOICE_LIST* GetBlobChoices(int col, int row, bool debug,
                                   const NetworkIO& output,
                                   const UNICHARSET* target_unicharset,
                                   int x_start, int x_end, float score_ratio);

  // Adds to the given iterator, the blob choices for the target_unicharset
  // that correspond to the given LSTM unichar_id.
  // Returns false if unicharset translation failed.
  // NOTE(review): rating, certainty, col and row are presumably the values
  // stored in each added choice -- confirm against the implementation.
  bool AddBlobChoices(int unichar_id, float rating, float certainty, int col,
                      int row, const UNICHARSET* target_unicharset,
                      BLOB_CHOICE_IT* bc_it);

  // Returns a string corresponding to the label starting at start. Sets *end
  // to the next start and if non-null, *decoded to the unichar id.
  // labels is the label sequence being decoded and start an index into it.
  // NOTE(review): the lifetime of the returned C string is not stated --
  // confirm against the implementation before storing it.
  const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
                          int* decoded);

  // Returns a string corresponding to a given single label id, falling back to
  // a default of ".." for part of a multi-label unichar-id.
  // NOTE(review): the lifetime of the returned C string is not stated --
  // confirm against the implementation before storing it.
  const char* DecodeSingleLabel(int label);

 protected:
  // The network hierarchy. OutputLossType() and the layer functions check it
  // against null; the other inline accessors dereference it directly.
  Network* network_;
  // The unicharset. Only the unicharset element is serialized.
  // Has to be a CCUtil, so Dict can point to it.
  // GetUnicharset() returns a reference to ccutil_.unicharset.
  CCUtil ccutil_;
  // For backward compatibility, recoder_ is serialized iff
  // training_flags_ & TF_COMPRESS_UNICHARSET.
  // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
  UnicharCompress recoder_;

  // ==Training parameters that are serialized to provide a record of them.==
  // NOTE(review): presumably the specification string that the network was
  // built from -- confirm against the implementation.
  STRING network_str_;
  // Flags used to determine the training method of the network.
  // See enum TrainingFlags above. Tested bitwise by IsIntMode(),
  // IsHardening(), IsRecoding() and CacheStrategy().
  inT32 training_flags_;
  // Number of actual backward training steps used.
  inT32 training_iteration_;
  // Index into training sample set. sample_iteration >= training_iteration_.
  // Set by SetIteration() and used by SetRandomSeed() to derive the seed.
  inT32 sample_iteration_;
  // Index in softmax of null character. May take the value UNICHAR_BROKEN or
  // ccutil_.unicharset.size().
  inT32 null_char_;
  // Range used for the initial random numbers in the weights.
  float weight_range_;
  // Learning rate and momentum multipliers of deltas in backprop.
  // learning_rate_ is scaled in place by ScaleLearningRate().
  float learning_rate_;
  float momentum_;

  // === NOT SERIALIZED.
  // Random number generator; seeded by SetRandomSeed().
  TRand randomizer_;
  // NOTE(review): presumably reusable working memory for network
  // computation -- confirm against the implementation.
  NetworkScratch scratch_space_;
  // Language model (optional) to use with the beam search.
  // Exposed read-only through GetDict().
  Dict* dict_;
  // Beam search held between uses to optimize memory allocation/use.
  RecodeBeamSearch* search_;

  // == Debugging parameters.==
  // Recognition debug display window.
  ScrollView* debug_win_;
};

}  // namespace tesseract.

#endif  // TESSERACT_LSTM_LSTMRECOGNIZER_H_