///////////////////////////////////////////////////////////////////////
// File:        lstmrecognizer.h
// Description: Top-level line recognizer class for LSTM-based networks.
// Author:      Ray Smith
// Created:     Thu May 02 08:57:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
#define TESSERACT_LSTM_LSTMRECOGNIZER_H_

#include "ccutil.h"
#include "helpers.h"
#include "imagedata.h"
#include "matrix.h"
#include "network.h"
#include "networkscratch.h"
#include "recodebeam.h"
#include "series.h"
#include "strngs.h"
#include "unicharcompress.h"

class BLOB_CHOICE_IT;
struct Pix;
class ROW_RES;
class ScrollView;
class TBOX;
class WERD_RES;

namespace tesseract {

class Dict;
class ImageData;

// Enum indicating training mode control flags.
// These are bit flags OR-ed together into training_flags_. The gaps in the
// numbering (4, 8, 32) presumably belonged to retired flags -- TODO confirm
// before reusing them, since training_flags_ is serialized.
enum TrainingFlags {
  TF_INT_MODE = 1,              // Run the network in integer mode (IsIntMode()).
  TF_AUTO_HARDEN = 2,           // Automatic hardening (IsHardening()).
  TF_ROUND_ROBIN_TRAINING = 16, // Round-robin document caching (CacheStrategy()).
  TF_COMPRESS_UNICHARSET = 64,  // recoder_ re-encodes unichar ids to a smaller
                                // space and is serialized (IsRecoding()).
};

// Top-level line recognizer class for LSTM-based networks.
// Note that a sub-class, LSTMTrainer is used for training.
class LSTMRecognizer {
 public:
  LSTMRecognizer();
  ~LSTMRecognizer();

  int NumOutputs() const {
    return network_->NumOutputs();
  }
  int training_iteration() const {
    return training_iteration_;
  }
  int sample_iteration() const {
    return sample_iteration_;
  }
  double learning_rate() const {
    return learning_rate_;
  }
  bool IsHardening() const {
    return (training_flags_ & TF_AUTO_HARDEN) != 0;
  }
  LossType OutputLossType() const {
    if (network_ == nullptr) return LT_NONE;
    StaticShape shape;
    shape = network_->OutputShape(shape);
    return shape.loss_type();
  }
  bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; }
  bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; }
  // True if recoder_ is active to re-encode text to a smaller space.
  bool IsRecoding() const {
    return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
  }
  // Returns the cache strategy for the DocumentCache.
  CachingStrategy CacheStrategy() const {
    return training_flags_ & TF_ROUND_ROBIN_TRAINING ? CS_ROUND_ROBIN
                                                     : CS_SEQUENTIAL;
  }
  // Returns true if the network is a TensorFlow network.
  bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }
  // Returns a vector of layer ids that can be passed to other layer functions
  // to access a specific layer.
  GenericVector<STRING> EnumerateLayers() const {
    ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
98
    Series* series = static_cast<Series*>(network_);
99 100 101 102 103 104 105 106
    GenericVector<STRING> layers;
    series->EnumerateLayers(NULL, &layers);
    return layers;
  }
  // Returns a specific layer from its id (from EnumerateLayers).
  Network* GetLayer(const STRING& id) const {
    ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
    ASSERT_HOST(id.length() > 1 && id[0] == ':');
107
    Series* series = static_cast<Series*>(network_);
108 109 110 111 112 113 114
    return series->GetLayer(&id[1]);
  }
  // Returns the learning rate of the layer from its id.
  float GetLayerLearningRate(const STRING& id) const {
    ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
    if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
      ASSERT_HOST(id.length() > 1 && id[0] == ':');
115
      Series* series = static_cast<Series*>(network_);
116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
      return series->LayerLearningRate(&id[1]);
    } else {
      return learning_rate_;
    }
  }
  // Multiplies the all the learning rate(s) by the given factor.
  void ScaleLearningRate(double factor) {
    ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
    learning_rate_ *= factor;
    if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
      GenericVector<STRING> layers = EnumerateLayers();
      for (int i = 0; i < layers.size(); ++i) {
        ScaleLayerLearningRate(layers[i], factor);
      }
    }
  }
  // Multiplies the learning rate of the layer with id, by the given factor.
  void ScaleLayerLearningRate(const STRING& id, double factor) {
    ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
    ASSERT_HOST(id.length() > 1 && id[0] == ':');
136
    Series* series = static_cast<Series*>(network_);
137 138 139 140 141 142 143
    series->ScaleLayerLearningRate(&id[1], factor);
  }

  // True if the network is using adagrad to train.
  bool IsUsingAdaGrad() const { return network_->TestFlag(NF_ADA_GRAD); }
  // Provides access to the UNICHARSET that this classifier works with.
  const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; }
144 145
  // Provides access to the Dict that this classifier works with.
  const Dict* GetDict() const { return dict_; }
146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
  // Sets the sample iteration to the given value. The sample_iteration_
  // determines the seed for the random number generator. The training
  // iteration is incremented only by a successful training iteration.
  void SetIteration(int iteration) {
    sample_iteration_ = iteration;
  }
  // Accessors for textline image normalization.
  int NumInputs() const {
    return network_->NumInputs();
  }
  int null_char() const { return null_char_; }

  // Writes to the given file. Returns false in case of error.
  bool Serialize(TFile* fp) const;
  // Reads from the given file. Returns false in case of error.
161
  bool DeSerialize(TFile* fp);
162 163 164 165 166 167 168
  // Loads the dictionary if possible from the traineddata file.
  // Prints a warning message, and returns false but otherwise fails silently
  // and continues to work without it if loading fails.
  // Note that dictionary load is independent from DeSerialize, but dependent
  // on the unicharset matching. This enables training to deserialize a model
  // from checkpoint or restore without having to go back and reload the
  // dictionary.
169
  bool LoadDictionary(const char* lang, TessdataManager* mgr);
170 171

  // Recognizes the line image, contained within image_data, returning the
172
  // recognized tesseract WERD_RES for the words.
173
  // If invert, tries inverted as well if the normal interpretation doesn't
174 175 176
  // produce a good enough result. The line_box is used for computing the
  // box_word in the output words. worst_dict_cert is the worst certainty that
  // will be used in a dictionary word.
177
  void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
178
                     double worst_dict_cert, const TBOX& line_box,
179 180 181 182 183 184 185
                     PointerVector<WERD_RES>* words);

  // Helper computes min and mean best results in the output.
  void OutputStats(const NetworkIO& outputs,
                   float* min_output, float* mean_output, float* sd);
  // Recognizes the image_data, returning the labels,
  // scores, and corresponding pairs of start, end x-coords in coords.
186
  // Returned in scale_factor is the reduction factor
187
  // between the image and the output coords, for computing bounding boxes.
S
Stefan Weil 已提交
188
  // If re_invert is true, the input is inverted back to its original
189 190 191
  // photometric interpretation if inversion is attempted but fails to
  // improve the results. This ensures that outputs contains the correct
  // forward outputs for the best photometric interpretation.
192
  // inputs is filled with the used inputs to the network.
193
  bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
194 195
                     bool re_invert, float* scale_factor, NetworkIO* inputs,
                     NetworkIO* outputs);
196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237

  // Converts an array of labels to utf-8, whether or not the labels are
  // augmented with character boundaries.
  STRING DecodeLabels(const GenericVector<int>& labels);

  // Displays the forward results in a window with the characters and
  // boundaries as determined by the labels and label_coords.
  void DisplayForward(const NetworkIO& inputs,
                      const GenericVector<int>& labels,
                      const GenericVector<int>& label_coords,
                      const char* window_name,
                      ScrollView** window);

 protected:
  // Sets the random seed from the sample_iteration_;
  void SetRandomSeed() {
    inT64 seed = static_cast<inT64>(sample_iteration_) * 0x10000001;
    randomizer_.set_seed(seed);
    randomizer_.IntRand();
  }

  // Displays the labels and cuts at the corresponding xcoords.
  // Size of labels should match xcoords.
  void DisplayLSTMOutput(const GenericVector<int>& labels,
                         const GenericVector<int>& xcoords,
                         int height, ScrollView* window);

  // Prints debug output detailing the activation path that is implied by the
  // xcoords.
  void DebugActivationPath(const NetworkIO& outputs,
                           const GenericVector<int>& labels,
                           const GenericVector<int>& xcoords);

  // Prints debug output detailing activations and 2nd choice over a range
  // of positions.
  void DebugActivationRange(const NetworkIO& outputs, const char* label,
                            int best_choice, int x_start, int x_end);

  // Converts the network output to a sequence of labels. Outputs labels, scores
  // and start xcoords of each char, and each null_char_, with an additional
  // final xcoord for the end of the output.
  // The conversion method is determined by internal state.
238
  void LabelsFromOutputs(const NetworkIO& outputs, GenericVector<int>* labels,
239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265
                         GenericVector<int>* xcoords);
  // As LabelsViaCTC except that this function constructs the best path that
  // contains only legal sequences of subcodes for recoder_.
  void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
                         GenericVector<int>* xcoords);
  // Converts the network output to a sequence of labels, with scores, using
  // the simple character model (each position is a char, and the null_char_ is
  // mainly intended for tail padding.)
  void LabelsViaSimpleText(const NetworkIO& output,
                           GenericVector<int>* labels,
                           GenericVector<int>* xcoords);

  // Returns a string corresponding to the label starting at start. Sets *end
  // to the next start and if non-null, *decoded to the unichar id.
  const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
                          int* decoded);

  // Returns a string corresponding to a given single label id, falling back to
  // a default of ".." for part of a multi-label unichar-id.
  const char* DecodeSingleLabel(int label);

 protected:
  // The network hierarchy.
  Network* network_;
  // The unicharset. Only the unicharset element is serialized.
  // Has to be a CCUtil, so Dict can point to it.
  CCUtil ccutil_;
S
Stefan Weil 已提交
266
  // For backward compatibility, recoder_ is serialized iff
267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304
  // training_flags_ & TF_COMPRESS_UNICHARSET.
  // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
  UnicharCompress recoder_;

  // ==Training parameters that are serialized to provide a record of them.==
  STRING network_str_;
  // Flags used to determine the training method of the network.
  // See enum TrainingFlags above.
  inT32 training_flags_;
  // Number of actual backward training steps used.
  inT32 training_iteration_;
  // Index into training sample set. sample_iteration >= training_iteration_.
  inT32 sample_iteration_;
  // Index in softmax of null character. May take the value UNICHAR_BROKEN or
  // ccutil_.unicharset.size().
  inT32 null_char_;
  // Range used for the initial random numbers in the weights.
  float weight_range_;
  // Learning rate and momentum multipliers of deltas in backprop.
  float learning_rate_;
  float momentum_;

  // === NOT SERIALIZED.
  TRand randomizer_;
  NetworkScratch scratch_space_;
  // Language model (optional) to use with the beam search.
  Dict* dict_;
  // Beam search held between uses to optimize memory allocation/use.
  RecodeBeamSearch* search_;

  // == Debugging parameters.==
  // Recognition debug display window.
  ScrollView* debug_win_;
};

}  // namespace tesseract.

#endif  // TESSERACT_LSTM_LSTMRECOGNIZER_H_