Commit ce76d1c5 authored by R Ray Smith

Fixes to training process to allow incremental training from a recognition model

Parent 9d905671
......@@ -64,6 +64,7 @@ void Tesseract::TrainLineRecognizer(const STRING& input_imagename,
return;
}
TrainFromBoxes(boxes, texts, block_list, &images);
images.Shuffle();
if (!images.SaveDocument(lstmf_name.string(), NULL)) {
tprintf("Failed to write training data to %s!\n", lstmf_name.string());
}
......@@ -79,7 +80,10 @@ void Tesseract::TrainFromBoxes(const GenericVector<TBOX>& boxes,
int box_count = boxes.size();
// Process all the text lines in this page, as defined by the boxes.
int end_box = 0;
for (int start_box = 0; start_box < box_count; start_box = end_box) {
// Don't let \t, which marks newlines in the box file, get into the line
// content, as that makes the line unusable in training.
while (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
for (int start_box = end_box; start_box < box_count; start_box = end_box) {
// Find the textline of boxes starting at start and their bounding box.
TBOX line_box = boxes[start_box];
STRING line_str = texts[start_box];
......@@ -115,7 +119,9 @@ void Tesseract::TrainFromBoxes(const GenericVector<TBOX>& boxes,
}
if (imagedata != NULL)
training_data->AddPageToDocument(imagedata);
if (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
// Don't let \t, which marks newlines in the box file, get into the line
// content, as that makes the line unusable in training.
while (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
}
}
......
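A minimal standalone sketch of the marker handling this hunk fixes (std::vector/std::string stand in for GenericVector/STRING, and the real loop additionally splits lines on box geometry, which is omitted here): the "\t" entries that mark newlines in the box file are consumed before the first line and after every line, so they never leak into line content.

#include <string>
#include <vector>

// Groups consecutive box texts into lines, skipping the "\t" end-of-line
// markers so they never enter the line content (simplified sketch).
std::vector<std::string> GroupTextLines(const std::vector<std::string>& texts) {
  std::vector<std::string> lines;
  size_t end_box = 0;
  // Skip any leading newline markers, as the patched TrainFromBoxes now does.
  while (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
  for (size_t start_box = end_box; start_box < texts.size();
       start_box = end_box) {
    std::string line = texts[start_box];
    for (end_box = start_box + 1;
         end_box < texts.size() && texts[end_box] != "\t"; ++end_box) {
      line += texts[end_box];
    }
    lines.push_back(line);
    // Consume trailing markers so the next line starts on real content.
    while (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
  }
  return lines;
}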
......@@ -55,6 +55,8 @@ bool ReadAllBoxes(int target_page, bool skip_blanks, const STRING& filename,
GenericVector<char> box_data;
if (!tesseract::LoadDataFromFile(BoxFileName(filename), &box_data))
return false;
// Convert the array of bytes to a string, so it can be used by the parser.
box_data.push_back('\0');
return ReadMemBoxes(target_page, skip_blanks, &box_data[0], boxes, texts,
box_texts, pages);
}
......
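The terminator matters because ReadMemBoxes parses the buffer as a C string. A hedged sketch of the new division of labor (LoadBytes is an illustrative name; the real loader is LoadDataFromFile, changed below to stop padding the buffer):

#include <cstdio>
#include <vector>

// The loader now returns exactly the file's bytes; binary callers no longer
// get a spurious pad byte, and text callers terminate the buffer themselves.
bool LoadBytes(const char* filename, std::vector<char>* data) {
  std::FILE* fp = std::fopen(filename, "rb");
  if (fp == nullptr) return false;
  std::fseek(fp, 0, SEEK_END);
  long size = std::ftell(fp);
  std::fseek(fp, 0, SEEK_SET);
  data->resize(static_cast<size_t>(size));
  bool ok = std::fread(data->data(), 1, data->size(), fp) == data->size();
  std::fclose(fp);
  return ok;
}

// A text-oriented caller then appends its own terminator, exactly as the
// patched ReadAllBoxes does before handing &box_data[0] to the parser:
//   data.push_back('\0');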
......@@ -24,18 +24,18 @@
#include "imagedata.h"
#if defined(__MINGW32__)
#include <unistd.h>
#else
#include <thread>
#endif
#include "allheaders.h"
#include "boxread.h"
#include "callcpp.h"
#include "helpers.h"
#include "tprintf.h"
#if defined(__MINGW32__)
# include <unistd.h>
#else
# include <thread>
#endif
// Number of documents to read ahead while training. Doesn't need to be very
// large.
const int kMaxReadAhead = 8;
......@@ -496,6 +496,21 @@ inT64 DocumentData::UnCache() {
return memory_saved;
}
// Shuffles all the pages in the document.
void DocumentData::Shuffle() {
TRand random;
// Different documents get shuffled differently, but the same for the same
// name.
random.set_seed(document_name_.string());
int num_pages = pages_.size();
// Execute one random swap for each page in the document.
for (int i = 0; i < num_pages; ++i) {
int src = random.IntRand() % num_pages;
int dest = random.IntRand() % num_pages;
std::swap(pages_[src], pages_[dest]);
}
}
// Locks the pages_mutex_ and Loads as many pages can fit in max_memory_
// starting at index pages_offset_.
bool DocumentData::ReCachePages() {
......
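A sketch of the idea behind the new DocumentData::Shuffle, assuming only the standard library (std::mt19937_64 and std::shuffle stand in for TRand and its swap loop): the seed is derived from the document name, so the page order is random yet reproducible for the same .lstmf file across runs.

#include <algorithm>
#include <functional>
#include <random>
#include <string>
#include <vector>

// Deterministic per-document shuffle: same name, same permutation.
template <typename Page>
void ShufflePages(const std::string& document_name, std::vector<Page>* pages) {
  std::mt19937_64 rng(std::hash<std::string>()(document_name));
  std::shuffle(pages->begin(), pages->end(), rng);
}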
......@@ -266,6 +266,8 @@ class DocumentData {
// Removes all pages from memory and frees the memory, but does not forget
// the document metadata. Returns the memory saved.
inT64 UnCache();
// Shuffles all the pages in the document.
void Shuffle();
private:
// Sets the value of total_pages_ behind a mutex.
......
......@@ -529,13 +529,12 @@ void WERD_RES::FilterWordChoices(int debug_level) {
if (choice->unichar_id(i) != best_choice->unichar_id(j) &&
choice->certainty(i) - best_choice->certainty(j) < threshold) {
if (debug_level >= 2) {
STRING label;
label.add_str_int("\nDiscarding bad choice #", index);
choice->print(label.string());
tprintf("i %d j %d Chunk %d Choice->Blob[i].Certainty %.4g"
" BestChoice->ChunkCertainty[Chunk] %g Threshold %g\n",
i, j, chunk, choice->certainty(i),
best_choice->certainty(j), threshold);
choice->print("WorstCertaintyDiffWorseThan");
tprintf(
"i %d j %d Choice->Blob[i].Certainty %.4g"
" WorstOtherChoiceCertainty %g Threshold %g\n",
i, j, choice->certainty(i), best_choice->certainty(j), threshold);
tprintf("Discarding bad choice #%d\n", index);
}
delete it.extract();
break;
......
......@@ -363,8 +363,7 @@ inline bool LoadDataFromFile(const STRING& filename,
fseek(fp, 0, SEEK_END);
size_t size = ftell(fp);
fseek(fp, 0, SEEK_SET);
// Pad with a 0, just in case we treat the result as a string.
data->init_to_size(static_cast<int>(size) + 1, 0);
data->init_to_size(static_cast<int>(size), 0);
bool result = fread(&(*data)[0], 1, size, fp) == size;
fclose(fp);
return result;
......@@ -380,6 +379,17 @@ inline bool SaveDataToFile(const GenericVector<char>& data,
fclose(fp);
return result;
}
// Reads a file as a vector of STRING.
inline bool LoadFileLinesToStrings(const STRING& filename,
GenericVector<STRING>* lines) {
GenericVector<char> data;
if (!LoadDataFromFile(filename.string(), &data)) {
return false;
}
STRING lines_str(&data[0], data.size());
lines_str.split('\n', lines);
return true;
}
template <typename T>
bool cmp_eq(T const & t1, T const & t2) {
......
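This helper is what the new --train_listfile and --eval_listfile flags are built on. A functionally equivalent sketch using plain iostreams:

#include <fstream>
#include <string>
#include <vector>

// Reads a whole file and splits it on '\n' into a vector of lines, as
// LoadFileLinesToStrings does with GenericVector<STRING>.
bool LoadLines(const std::string& filename, std::vector<std::string>* lines) {
  std::ifstream in(filename);
  if (!in) return false;
  std::string line;
  while (std::getline(in, line)) lines->push_back(line);
  return true;
}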
......@@ -27,6 +27,8 @@
#include <stdio.h>
#include <string.h>
#include <functional>
#include <string>
#include "host.h"
......@@ -43,6 +45,11 @@ class TRand {
void set_seed(uinT64 seed) {
seed_ = seed;
}
// Sets the seed using a hash of a string.
void set_seed(const std::string& str) {
std::hash<std::string> hasher;
set_seed(static_cast<uinT64>(hasher(str)));
}
// Returns an integer in the range 0 to MAX_INT32.
inT32 IntRand() {
......
......@@ -56,6 +56,17 @@ StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const {
return result;
}
// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if training_ is not TS_ENABLED.
void FullyConnected::SetEnableTraining(TrainingState state) {
if (state == TS_RE_ENABLE) {
if (training_ == TS_DISABLED) weights_.InitBackward(false);
training_ = TS_ENABLED;
} else {
training_ = state;
}
}
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
int FullyConnected::InitWeights(float range, TRand* randomizer) {
......@@ -78,14 +89,14 @@ void FullyConnected::DebugWeights() {
// Writes to the given file. Returns false in case of error.
bool FullyConnected::Serialize(TFile* fp) const {
if (!Network::Serialize(fp)) return false;
if (!weights_.Serialize(training_, fp)) return false;
if (!weights_.Serialize(IsTraining(), fp)) return false;
return true;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool FullyConnected::DeSerialize(bool swap, TFile* fp) {
if (!weights_.DeSerialize(training_, swap, fp)) return false;
if (!weights_.DeSerialize(IsTraining(), swap, fp)) return false;
return true;
}
......@@ -129,14 +140,14 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
}
ForwardTimeStep(d_input, i_input, t, temp_line);
output->WriteTimeStep(t, temp_line);
if (training() && type_ != NT_SOFTMAX) {
if (IsTraining() && type_ != NT_SOFTMAX) {
acts_.CopyTimeStepFrom(t, *output, t);
}
}
// Zero all the elements that are in the padding around images that allows
// multiple different-sized images to exist in a single array.
// acts_ is only used if this is not a softmax op.
if (training() && type_ != NT_SOFTMAX) {
if (IsTraining() && type_ != NT_SOFTMAX) {
acts_.ZeroInvalidElements();
}
output->ZeroInvalidElements();
......@@ -152,7 +163,7 @@ void FullyConnected::SetupForward(const NetworkIO& input,
const TransposedArray* input_transpose) {
// Softmax output is always float, so save the input type.
int_mode_ = input.int_mode();
if (training()) {
if (IsTraining()) {
acts_.Resize(input, no_);
// Source_ is a transposed copy of input. It isn't needed if provided.
external_source_ = input_transpose;
......@@ -163,7 +174,7 @@ void FullyConnected::SetupForward(const NetworkIO& input,
void FullyConnected::ForwardTimeStep(const double* d_input, const inT8* i_input,
int t, double* output_line) {
// input is copied to source_ line-by-line for cache coherency.
if (training() && external_source_ == NULL && d_input != NULL)
if (IsTraining() && external_source_ == NULL && d_input != NULL)
source_t_.WriteStrided(t, d_input);
if (d_input != NULL)
weights_.MatrixDotVector(d_input, output_line);
......
......@@ -61,6 +61,10 @@ class FullyConnected : public Network {
type_ = type;
}
// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if training_ is not TS_ENABLED.
virtual void SetEnableTraining(TrainingState state);
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
virtual int InitWeights(float range, TRand* randomizer);
......
......@@ -102,6 +102,23 @@ StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
return result;
}
// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if training_ is not TS_ENABLED.
void LSTM::SetEnableTraining(TrainingState state) {
if (state == TS_RE_ENABLE) {
if (training_ == TS_DISABLED) {
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
gate_weights_[w].InitBackward(false);
}
}
training_ = TS_ENABLED;
} else {
training_ = state;
}
if (softmax_ != NULL) softmax_->SetEnableTraining(state);
}
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
int LSTM::InitWeights(float range, TRand* randomizer) {
......@@ -148,7 +165,7 @@ bool LSTM::Serialize(TFile* fp) const {
if (fp->FWrite(&na_, sizeof(na_), 1) != 1) return false;
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
if (!gate_weights_[w].Serialize(training_, fp)) return false;
if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
}
if (softmax_ != NULL && !softmax_->Serialize(fp)) return false;
return true;
......@@ -169,7 +186,7 @@ bool LSTM::DeSerialize(bool swap, TFile* fp) {
is_2d_ = false;
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
if (!gate_weights_[w].DeSerialize(training_, swap, fp)) return false;
if (!gate_weights_[w].DeSerialize(IsTraining(), swap, fp)) return false;
if (w == CI) {
ns_ = gate_weights_[CI].NumOutputs();
is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
......@@ -322,7 +339,7 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
// Clip curr_state to a sane range.
ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
if (training_) {
if (IsTraining()) {
// Save the gate node values.
node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
......@@ -331,7 +348,7 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
}
FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
if (training_) state_.WriteTimeStep(t, curr_state);
if (IsTraining()) state_.WriteTimeStep(t, curr_state);
if (softmax_ != NULL) {
if (input.int_mode()) {
int_output->WriteTimeStep(0, curr_output);
......@@ -697,7 +714,7 @@ void LSTM::PrintDW() {
void LSTM::ResizeForward(const NetworkIO& input) {
source_.Resize(input, na_);
which_fg_.ResizeNoInit(input.Width(), ns_);
if (training_) {
if (IsTraining()) {
state_.ResizeFloat(input, ns_);
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
......
......@@ -69,6 +69,10 @@ class LSTM : public Network {
return spec;
}
// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if training_ is not TS_ENABLED.
virtual void SetEnableTraining(TrainingState state);
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
virtual int InitWeights(float range, TRand* randomizer);
......
......@@ -253,7 +253,7 @@ bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
float label_threshold, float* scale_factor,
NetworkIO* inputs, NetworkIO* outputs) {
// Maximum width of image to train on.
const int kMaxImageWidth = 2048;
const int kMaxImageWidth = 2560;
// This ensures consistent recognition results.
SetRandomSeed();
int min_width = network_->XScaleFactor();
......@@ -263,7 +263,7 @@ bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
tprintf("Line cannot be recognized!!\n");
return false;
}
if (network_->training() && pixGetWidth(pix) > kMaxImageWidth) {
if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) {
tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix),
pixGetHeight(pix));
pixDestroy(&pix);
......
......@@ -134,8 +134,6 @@ bool LSTMTrainer::TryLoadingCheckpoint(const char* filename) {
// Note: Call before InitNetwork!
void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset,
const STRING& script_dir, int train_flags) {
// Call before InitNetwork.
ASSERT_HOST(network_ == NULL);
EmptyConstructor();
training_flags_ = train_flags;
ccutil_.unicharset.CopyFrom(unicharset);
......@@ -150,8 +148,6 @@ void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset,
// Note: Call before InitNetwork!
void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset,
const UnicharCompress recoder) {
// Call before InitNetwork.
ASSERT_HOST(network_ == NULL);
EmptyConstructor();
int flags = TF_COMPRESS_UNICHARSET;
training_flags_ = static_cast<TrainingFlags>(flags);
......@@ -219,6 +215,30 @@ int LSTMTrainer::InitTensorFlowNetwork(const std::string& tf_proto) {
#endif
}
// Resets all the iteration counters for fine tuning or training a head,
// where we want the error reporting to reset.
void LSTMTrainer::InitIterations() {
sample_iteration_ = 0;
training_iteration_ = 0;
learning_iteration_ = 0;
prev_sample_iteration_ = 0;
best_error_rate_ = 100.0;
best_iteration_ = 0;
worst_error_rate_ = 0.0;
worst_iteration_ = 0;
stall_iteration_ = kMinStallIterations;
improvement_steps_ = kMinStallIterations;
perfect_delay_ = 0;
last_perfect_training_iteration_ = 0;
for (int i = 0; i < ET_COUNT; ++i) {
best_error_rates_[i] = 100.0;
worst_error_rates_[i] = 0.0;
error_buffers_[i].init_to_size(kRollingBufferSize_, 0.0);
error_rates_[i] = 100.0;
}
error_rate_of_last_saved_best_ = kMinStartedErrorRate;
}
// If the training sample is usable, grid searches for the optimal
// dict_ratio/cert_offset, and returns the results in a string of space-
// separated triplets of ratio,offset=worderr.
......@@ -460,8 +480,15 @@ bool LSTMTrainer::Serialize(TFile* fp) const {
// If swap is true, assumes a big/little-endian swap is needed.
bool LSTMTrainer::DeSerialize(bool swap, TFile* fp) {
if (!LSTMRecognizer::DeSerialize(swap, fp)) return false;
if (fp->FRead(&learning_iteration_, sizeof(learning_iteration_), 1) != 1)
return false;
if (fp->FRead(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) {
// Special case: if we successfully decoded the recognizer but fail here,
// it means we were just given a recognizer, so issue a warning and
// allow it.
tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
learning_iteration_ = 0;
network_->SetEnableTraining(TS_RE_ENABLE);
return true;
}
if (fp->FRead(&prev_sample_iteration_, sizeof(prev_sample_iteration_), 1) !=
1)
return false;
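This fallback is the heart of "incremental training from a recognition model": a dump that ends right after the recognizer portion is accepted as a starting point rather than rejected. A hypothetical sketch of the control flow (ToyTrainer and the raw FILE calls are illustrative, not the real TFile/LSTMTrainer API):

#include <cstdio>

struct ToyTrainer {
  int learning_iteration = 0;
  bool ReadTrainerFields(std::FILE* fp) {
    if (std::fread(&learning_iteration, sizeof(learning_iteration), 1, fp) !=
        1) {
      // Recognition-only model: warn, reset the counters, and rebuild the
      // backward buffers (SetEnableTraining(TS_RE_ENABLE) in the real code).
      std::fprintf(stderr, "Warning: deserialized a recognizer\n");
      learning_iteration = 0;
      return true;
    }
    // ... otherwise read the remaining trainer-only fields ...
    return true;
  }
};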
......@@ -629,7 +656,7 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
SaveTrainingDump(LIGHT, this, &orig_trainer);
for (int i = 0; i < num_layers; ++i) {
Network* layer = GetLayer(layers[i]);
num_weights[i] = layer->training() ? layer->num_weights() : 0;
num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
}
int iteration = sample_iteration();
for (int s = 0; s < num_samples; ++s) {
......@@ -773,7 +800,7 @@ Trainability LSTMTrainer::TrainOnLine(const ImageData* trainingdata,
training_iteration() % debug_interval_ == 0;
// Run backprop on the output.
NetworkIO bp_deltas;
if (network_->training() &&
if (network_->IsTraining() &&
(trainable != PERFECT ||
training_iteration() >
last_perfect_training_iteration_ + perfect_delay_)) {
......@@ -827,6 +854,7 @@ Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata,
return UNENCODABLE;
}
targets->Resize(*fwd_outputs, network_->NumOutputs());
double text_error = 100.0;
LossType loss_type = OutputLossType();
if (loss_type == LT_SOFTMAX) {
if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) {
......@@ -900,9 +928,9 @@ bool LSTMTrainer::ReadSizedTrainingDump(const char* data, int size) {
void LSTMTrainer::SaveRecognitionDump(GenericVector<char>* data) const {
TFile fp;
fp.OpenWrite(data);
network_->SetEnableTraining(false);
network_->SetEnableTraining(TS_TEMP_DISABLE);
ASSERT_HOST(LSTMRecognizer::Serialize(&fp));
network_->SetEnableTraining(true);
network_->SetEnableTraining(TS_RE_ENABLE);
}
// Reads and returns a previously saved recognizer from memory.
......@@ -942,25 +970,7 @@ void LSTMTrainer::EmptyConstructor() {
serialize_amount_ = FULL;
training_stage_ = 0;
num_training_stages_ = 2;
prev_sample_iteration_ = 0;
best_error_rate_ = 100.0;
best_iteration_ = 0;
worst_error_rate_ = 0.0;
worst_iteration_ = 0;
stall_iteration_ = kMinStallIterations;
learning_iteration_ = 0;
improvement_steps_ = kMinStallIterations;
perfect_delay_ = 0;
last_perfect_training_iteration_ = 0;
for (int i = 0; i < ET_COUNT; ++i) {
best_error_rates_[i] = 100.0;
worst_error_rates_[i] = 0.0;
error_buffers_[i].init_to_size(kRollingBufferSize_, 0.0);
error_rates_[i] = 100.0;
}
sample_iteration_ = 0;
training_iteration_ = 0;
error_rate_of_last_saved_best_ = kMinStartedErrorRate;
InitIterations();
}
// Sets the unicharset properties using the given script_dir as a source of
......
......@@ -127,6 +127,9 @@ class LSTMTrainer : public LSTMRecognizer {
// Returns the global step of TensorFlow graph or 0 if failed.
// Building a compatible TF graph: See tfnetwork.proto.
int InitTensorFlowNetwork(const std::string& tf_proto);
// Resets all the iteration counters for fine tuning or training a head,
// where we want the error reporting to reset.
void InitIterations();
// Accessors.
double ActivationError() const {
......
......@@ -69,23 +69,47 @@ char const* const Network::kTypeNames[NT_COUNT] = {
};
Network::Network()
: type_(NT_NONE), training_(true), needs_to_backprop_(true),
network_flags_(0), ni_(0), no_(0), num_weights_(0),
forward_win_(NULL), backward_win_(NULL), randomizer_(NULL) {
}
: type_(NT_NONE),
training_(TS_ENABLED),
needs_to_backprop_(true),
network_flags_(0),
ni_(0),
no_(0),
num_weights_(0),
forward_win_(NULL),
backward_win_(NULL),
randomizer_(NULL) {}
Network::Network(NetworkType type, const STRING& name, int ni, int no)
: type_(type), training_(true), needs_to_backprop_(true),
network_flags_(0), ni_(ni), no_(no), num_weights_(0),
name_(name), forward_win_(NULL), backward_win_(NULL), randomizer_(NULL) {
}
: type_(type),
training_(TS_ENABLED),
needs_to_backprop_(true),
network_flags_(0),
ni_(ni),
no_(no),
num_weights_(0),
name_(name),
forward_win_(NULL),
backward_win_(NULL),
randomizer_(NULL) {}
Network::~Network() {
}
// Ends training by setting the training_ flag to false. Serialize and
// DeSerialize will now only operate on the run-time data.
void Network::SetEnableTraining(bool state) {
training_ = state;
// Suspends/Enables/Permanently disables training by setting the training_
// flag. Serialize and DeSerialize only operate on the run-time data if state
// is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
// temporarily disable layers in state TS_ENABLED, allowing a trainer to
// serialize as if it were a recognizer.
// TS_RE_ENABLE will re-enable layers that were previously in any disabled
// state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
// TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
// recognizer can be converted back to a trainer.
void Network::SetEnableTraining(TrainingState state) {
if (state == TS_RE_ENABLE) {
training_ = TS_ENABLED;
} else {
training_ = state;
}
}
// Sets flags that control the action of the network. See NetworkFlags enum
......@@ -152,7 +176,7 @@ bool Network::DeSerialize(bool swap, TFile* fp) {
}
type_ = static_cast<NetworkType>(data);
if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
training_ = data != 0;
training_ = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
needs_to_backprop_ = data != 0;
if (fp->FRead(&network_flags_, sizeof(network_flags_), 1) != 1) return false;
......
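A toy model of the TrainingState machinery (ToyLayer is hypothetical; the real logic is spread across Network::SetEnableTraining and the FullyConnected/LSTM overrides): TS_TEMP_DISABLE to TS_RE_ENABLE is a cheap flag flip, while re-enabling a TS_DISABLED layer, such as one deserialized from a recognition dump, must rebuild the backward delta buffers, which InitBackward(false) does.

#include <cassert>

enum TrainingState { TS_DISABLED, TS_ENABLED, TS_TEMP_DISABLE, TS_RE_ENABLE };

struct ToyLayer {
  TrainingState training;
  bool has_deltas;  // stands in for the weight matrices' dw_/updates_.

  void SetEnableTraining(TrainingState state) {
    if (state == TS_RE_ENABLE) {
      if (training == TS_DISABLED) has_deltas = true;  // InitBackward(false)
      training = TS_ENABLED;
    } else {
      training = state;
    }
  }
};

int main() {
  // A model loaded from a recognition dump comes up disabled, with no deltas.
  ToyLayer layer{TS_DISABLED, false};
  layer.SetEnableTraining(TS_RE_ENABLE);  // converts it back to a trainer.
  assert(layer.training == TS_ENABLED && layer.has_deltas);

  layer.SetEnableTraining(TS_TEMP_DISABLE);  // serialize as a recognizer...
  layer.SetEnableTraining(TS_RE_ENABLE);     // ...then resume training.
  assert(layer.training == TS_ENABLED && layer.has_deltas);
  return 0;
}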
......@@ -88,6 +88,16 @@ enum NetworkFlags {
NF_ADA_GRAD = 128, // Weight-specific learning rate.
};
// State of training and desired state used in SetEnableTraining.
enum TrainingState {
// Valid states of training_.
TS_DISABLED, // Disabled permanently.
TS_ENABLED, // Enabled for backprop and to write a training dump.
TS_TEMP_DISABLE, // Temporarily disabled to write a recognition dump.
// Valid only for SetEnableTraining.
TS_RE_ENABLE, // Re-enable training, whatever the current state.
};
// Base class for network types. Not quite an abstract base class, but almost.
// Most of the time no isolated Network exists, except prior to
// deserialization.
......@@ -101,9 +111,7 @@ class Network {
NetworkType type() const {
return type_;
}
bool training() const {
return training_;
}
bool IsTraining() const { return training_ == TS_ENABLED; }
bool needs_to_backprop() const {
return needs_to_backprop_;
}
......@@ -142,9 +150,16 @@ class Network {
// multiple sub-networks that can have their own learning rate.
virtual bool IsPlumbingType() const { return false; }
// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if state is false.
virtual void SetEnableTraining(bool state);
// Suspends/Enables/Permanently disables training by setting the training_
// flag. Serialize and DeSerialize only operate on the run-time data if state
// is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
// temporarily disable layers in state TS_ENABLED, allowing a trainer to
// serialize as if it were a recognizer.
// TS_RE_ENABLE will re-enable layers that were previously in any disabled
// state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
// TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
// recognizer can be converted back to a trainer.
virtual void SetEnableTraining(TrainingState state);
// Sets flags that control the action of the network. See NetworkFlags enum
// for bit values.
......@@ -269,7 +284,7 @@ class Network {
protected:
NetworkType type_; // Type of the derived network class.
bool training_; // Are we currently training?
TrainingState training_; // Are we currently training?
bool needs_to_backprop_; // This network needs to output back_deltas.
inT32 network_flags_; // Behavior control flags in NetworkFlags.
inT32 ni_; // Number of input values.
......
......@@ -83,7 +83,7 @@ void Parallel::Forward(bool debug, const NetworkIO& input,
// Source for divided replicated.
NetworkScratch::IO source_part;
TransposedArray* src_transpose = NULL;
if (training() && type_ == NT_REPLICATED) {
if (IsTraining() && type_ == NT_REPLICATED) {
// Make a transposed copy of the input.
input.Transpose(&transposed_input_);
src_transpose = &transposed_input_;
......
......@@ -31,7 +31,7 @@ Plumbing::~Plumbing() {
// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if training_ is not TS_ENABLED.
void Plumbing::SetEnableTraining(bool state) {
void Plumbing::SetEnableTraining(TrainingState state) {
Network::SetEnableTraining(state);
for (int i = 0; i < stack_.size(); ++i)
stack_[i]->SetEnableTraining(state);
......@@ -91,13 +91,17 @@ void Plumbing::AddToStack(Network* network) {
// Sets needs_to_backprop_ to needs_backprop and calls on sub-network
// according to needs_backprop || any weights in this network.
bool Plumbing::SetupNeedsBackprop(bool needs_backprop) {
needs_to_backprop_ = needs_backprop;
bool retval = needs_backprop;
for (int i = 0; i < stack_.size(); ++i) {
if (stack_[i]->SetupNeedsBackprop(needs_backprop))
retval = true;
if (IsTraining()) {
needs_to_backprop_ = needs_backprop;
bool retval = needs_backprop;
for (int i = 0; i < stack_.size(); ++i) {
if (stack_[i]->SetupNeedsBackprop(needs_backprop)) retval = true;
}
return retval;
}
return retval;
// Frozen networks don't do backprop.
needs_to_backprop_ = false;
return false;
}
// Returns an integer reduction factor that the network applies to the
......@@ -212,8 +216,9 @@ void Plumbing::Update(float learning_rate, float momentum, int num_samples) {
else
learning_rates_.push_back(learning_rate);
}
if (stack_[i]->training())
if (stack_[i]->IsTraining()) {
stack_[i]->Update(learning_rate, momentum, num_samples);
}
}
}
......
......@@ -45,7 +45,7 @@ class Plumbing : public Network {
// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if training_ is not TS_ENABLED.
virtual void SetEnableTraining(bool state);
virtual void SetEnableTraining(TrainingState state);
// Sets flags that control the action of the network. See NetworkFlags enum
// for bit values.
......
......@@ -116,7 +116,7 @@ void Series::Forward(bool debug, const NetworkIO& input,
bool Series::Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
if (!training()) return false;
if (!IsTraining()) return false;
int stack_size = stack_.size();
ASSERT_HOST(stack_size > 1);
// Revolving intermediate buffers.
......@@ -124,16 +124,16 @@ bool Series::Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch::IO buffer2(fwd_deltas, scratch);
// Run each network in reverse order, giving the back_deltas output of n as
// the fwd_deltas input to n-1, with the 0 network providing the real output.
if (!stack_.back()->training() ||
if (!stack_.back()->IsTraining() ||
!stack_.back()->Backward(debug, fwd_deltas, scratch, buffer1))
return false;
for (int i = stack_size - 2; i >= 0; i -= 2) {
if (!stack_[i]->training() ||
if (!stack_[i]->IsTraining() ||
!stack_[i]->Backward(debug, *buffer1, scratch,
i > 0 ? buffer2 : back_deltas))
return false;
if (i == 0) return needs_to_backprop_;
if (!stack_[i - 1]->training() ||
if (!stack_[i - 1]->IsTraining() ||
!stack_[i - 1]->Backward(debug, *buffer2, scratch,
i > 1 ? buffer1 : back_deltas))
return false;
......
......@@ -69,8 +69,8 @@ class StrideMap {
bool IsValid() const;
// Returns true if the index of the given dimension is the last.
bool IsLast(FlexDimensions dimension) const;
// Given that the dimensions up to and including dim-1 are valid, returns the
// maximum index for dimension dim.
// Given that the dimensions up to and including dim-1 are valid, returns
// the maximum index for dimension dim.
int MaxIndexOfDim(FlexDimensions dim) const;
// Adds the given offset to the given dimension. Returns true if the result
// makes a valid index.
......
......@@ -98,8 +98,6 @@ void TransposedArray::Transpose(const GENERIC_2D_ARRAY<double>& input) {
int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
float weight_range, TRand* randomizer) {
int_mode_ = false;
use_ada_grad_ = ada_grad;
if (use_ada_grad_) dw_sq_sum_.Resize(no, ni, 0.0);
wf_.Resize(no, ni, 0.0);
if (randomizer != NULL) {
for (int i = 0; i < no; ++i) {
......@@ -108,7 +106,7 @@ int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
}
}
}
InitBackward();
InitBackward(ada_grad);
return ni * no;
}
......@@ -144,12 +142,14 @@ void WeightMatrix::ConvertToInt() {
// Allocates any needed memory for running Backward, and zeroes the deltas,
// thus eliminating any existing momentum.
void WeightMatrix::InitBackward() {
void WeightMatrix::InitBackward(bool ada_grad) {
int no = int_mode_ ? wi_.dim1() : wf_.dim1();
int ni = int_mode_ ? wi_.dim2() : wf_.dim2();
use_ada_grad_ = ada_grad;
dw_.Resize(no, ni, 0.0);
updates_.Resize(no, ni, 0.0);
wf_t_.Transpose(wf_);
if (use_ada_grad_) dw_sq_sum_.Resize(no, ni, 0.0);
}
// Flag on mode to indicate that this weightmatrix uses inT8.
......@@ -193,7 +193,7 @@ bool WeightMatrix::DeSerialize(bool training, bool swap, TFile* fp) {
} else {
if (!wf_.DeSerialize(swap, fp)) return false;
if (training) {
InitBackward();
InitBackward(use_ada_grad_);
if (!updates_.DeSerialize(swap, fp)) return false;
if (use_ada_grad_ && !dw_sq_sum_.DeSerialize(swap, fp)) return false;
}
......@@ -216,7 +216,7 @@ bool WeightMatrix::DeSerializeOld(bool training, bool swap, TFile* fp) {
FloatToDouble(float_array, &wf_);
}
if (training) {
InitBackward();
InitBackward(use_ada_grad_);
if (!float_array.DeSerialize(swap, fp)) return false;
FloatToDouble(float_array, &updates_);
// Errs was only used in int training, which is now dead.
......
......@@ -92,7 +92,7 @@ class WeightMatrix {
// Allocates any needed memory for running Backward, and zeroes the deltas,
// thus eliminating any existing momentum.
void InitBackward();
void InitBackward(bool ada_grad);
// Writes to the given file. Returns false in case of error.
bool Serialize(bool training, TFile* fp) const;
......
......@@ -27,7 +27,7 @@ endif
noinst_HEADERS = \
boxchar.h commandlineflags.h commontraining.h degradeimage.h \
fileio.h icuerrorcode.h ligature_table.h normstrngs.h \
fileio.h icuerrorcode.h ligature_table.h lstmtester.h normstrngs.h \
mergenf.h pango_font_info.h stringrenderer.h \
tessopt.h tlog.h unicharset_training_utils.h util.h
......@@ -39,14 +39,14 @@ libtesseract_training_la_LIBADD = \
libtesseract_training_la_SOURCES = \
boxchar.cpp commandlineflags.cpp commontraining.cpp degradeimage.cpp \
fileio.cpp ligature_table.cpp normstrngs.cpp pango_font_info.cpp \
fileio.cpp ligature_table.cpp lstmtester.cpp normstrngs.cpp pango_font_info.cpp \
stringrenderer.cpp tlog.cpp unicharset_training_utils.cpp
libtesseract_tessopt_la_SOURCES = \
tessopt.cpp
bin_PROGRAMS = ambiguous_words classifier_tester cntraining combine_tessdata \
dawg2wordlist lstmtraining mftraining set_unicharset_properties shapeclustering \
dawg2wordlist lstmeval lstmtraining mftraining set_unicharset_properties shapeclustering \
text2image unicharset_extractor wordlist2dawg
ambiguous_words_SOURCES = ambiguous_words.cpp
......@@ -163,6 +163,33 @@ dawg2wordlist_LDADD += \
../api/libtesseract.la
endif
lstmeval_SOURCES = lstmeval.cpp
#lstmeval_LDFLAGS = -static
lstmeval_LDADD = \
libtesseract_training.la \
libtesseract_tessopt.la \
$(libicu)
if USING_MULTIPLELIBS
lstmeval_LDADD += \
../textord/libtesseract_textord.la \
../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \
../ccmain/libtesseract_main.la \
../cube/libtesseract_cube.la \
../neural_networks/runtime/libtesseract_neural.la \
../wordrec/libtesseract_wordrec.la \
../ccutil/libtesseract_ccutil.la
else
lstmeval_LDADD += \
../api/libtesseract.la
endif
lstmtraining_SOURCES = lstmtraining.cpp
#lstmtraining_LDFLAGS = -static
lstmtraining_LDADD = \
......
///////////////////////////////////////////////////////////////////////
// File: lstmeval.cpp
// Description: Evaluation program for LSTM-based networks.
// Author: Ray Smith
// Created: Wed Nov 23 12:20:06 PST 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef USE_STD_NAMESPACE
#include "base/commandlineflags.h"
#endif
#include "commontraining.h"
#include "genericvector.h"
#include "lstmtester.h"
#include "strngs.h"
#include "tprintf.h"
STRING_PARAM_FLAG(model, "", "Name of model file (training or recognition)");
STRING_PARAM_FLAG(eval_listfile, "",
"File listing sample files in lstmf training format.");
INT_PARAM_FLAG(max_image_MB, 2000, "Max memory to use for images.");
int main(int argc, char **argv) {
ParseArguments(&argc, &argv);
if (FLAGS_model.empty()) {
tprintf("Must provide a --model!\n");
return 1;
}
if (FLAGS_eval_listfile.empty()) {
tprintf("Must provide a --eval_listfile!\n");
return 1;
}
GenericVector<char> model_data;
if (!tesseract::LoadDataFromFile(FLAGS_model.c_str(), &model_data)) {
tprintf("Failed to load model from: %s\n", FLAGS_eval_listfile.c_str());
return 1;
}
tesseract::LSTMTester tester(static_cast<inT64>(FLAGS_max_image_MB) *
1048576);
if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
tprintf("Failed to load eval data from: %s\n", FLAGS_eval_listfile.c_str());
return 1;
}
double errs = 0.0;
STRING result = tester.RunEvalSync(0, &errs, model_data, 0);
tprintf("%s\n", result.string());
return 0;
} /* main */
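A typical invocation, with illustrative file names: lstmeval --model eng.checkpoint --eval_listfile eng.training_files.txt, optionally with --max_image_MB to bound the image cache. Per the --model flag's description above, the model may be either a training or a recognition model.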
///////////////////////////////////////////////////////////////////////
// File: lstmtester.cpp
// Description: Top-level line evaluation class for LSTM-based networks.
// Author: Ray Smith
// Created: Wed Nov 23 11:18:06 PST 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "lstmtester.h"
#include "genericvector.h"
namespace tesseract {
LSTMTester::LSTMTester(inT64 max_memory)
: test_data_(max_memory), total_pages_(0), async_running_(false) {}
// Loads a set of lstmf files, made by running tesseract with the lstm.train
// config, into memory ready for testing. Returns false if nothing was
// loaded. The arg names a file that lists the filenames.
bool LSTMTester::LoadAllEvalData(const STRING& filenames_file) {
GenericVector<STRING> filenames;
if (!LoadFileLinesToStrings(filenames_file, &filenames)) {
tprintf("Failed to load list of eval filenames from %s\n",
filenames_file.string());
return false;
}
return LoadAllEvalData(filenames);
}
// Loads a set of lstmf files, made by running tesseract with the lstm.train
// config, into memory ready for testing. Returns false if nothing was
// loaded.
bool LSTMTester::LoadAllEvalData(const GenericVector<STRING>& filenames) {
test_data_.Clear();
bool result =
test_data_.LoadDocuments(filenames, "eng", CS_SEQUENTIAL, nullptr);
total_pages_ = test_data_.TotalPages();
return result;
}
// Runs an evaluation asynchronously on the stored data and returns a string
// describing the results of the previous test.
STRING LSTMTester::RunEvalAsync(int iteration, const double* training_errors,
const GenericVector<char>& model_data,
int training_stage) {
STRING result;
if (total_pages_ == 0) {
result.add_str_int("No test data at iteration", iteration);
return result;
}
if (!LockIfNotRunning()) {
result.add_str_int("Previous test incomplete, skipping test at iteration",
iteration);
return result;
}
// Save the args.
STRING prev_result = test_result_;
test_result_ = "";
if (training_errors != nullptr) {
test_iteration_ = iteration;
test_training_errors_ = training_errors;
test_model_data_ = model_data;
test_training_stage_ = training_stage;
SVSync::StartThread(&LSTMTester::ThreadFunc, this);
} else {
UnlockRunning();
}
return prev_result;
}
// Runs an evaluation synchronously on the stored data and returns a string
// describing the results.
STRING LSTMTester::RunEvalSync(int iteration, const double* training_errors,
const GenericVector<char>& model_data,
int training_stage) {
LSTMTrainer trainer;
if (!trainer.ReadTrainingDump(model_data, &trainer)) {
return "Deserialize failed";
}
int eval_iteration = 0;
double char_error = 0.0;
double word_error = 0.0;
int error_count = 0;
while (error_count < total_pages_) {
const ImageData* trainingdata = test_data_.GetPageBySerial(eval_iteration);
trainer.SetIteration(++eval_iteration);
NetworkIO fwd_outputs, targets;
if (trainer.PrepareForBackward(trainingdata, &fwd_outputs, &targets) !=
UNENCODABLE) {
char_error += trainer.NewSingleError(tesseract::ET_CHAR_ERROR);
word_error += trainer.NewSingleError(tesseract::ET_WORD_RECERR);
++error_count;
}
}
char_error *= 100.0 / total_pages_;
word_error *= 100.0 / total_pages_;
STRING result;
result.add_str_int("At iteration ", iteration);
result.add_str_int(", stage ", training_stage);
result.add_str_double(", Eval Char error rate=", char_error);
result.add_str_double(", Word error rate=", word_error);
return result;
}
// Static helper thread function for RunEvalAsync, with a specific signature
// required by SVSync::StartThread. Actually a member function pretending to
// be static, its arg is a this pointer that it will cast back to LSTMTester*
// to call RunEvalSync using the stored args that RunEvalAsync saves in *this.
// LockIfNotRunning must have returned true before calling ThreadFunc, and
// it will call UnlockRunning to release the lock after RunEvalSync completes.
/* static */
void* LSTMTester::ThreadFunc(void* lstmtester_void) {
LSTMTester* lstmtester = reinterpret_cast<LSTMTester*>(lstmtester_void);
lstmtester->test_result_ = lstmtester->RunEvalSync(
lstmtester->test_iteration_, lstmtester->test_training_errors_,
lstmtester->test_model_data_, lstmtester->test_training_stage_);
lstmtester->UnlockRunning();
return lstmtester_void;
}
// Returns true if there is currently nothing running, and takes the lock
// if there is nothing running.
bool LSTMTester::LockIfNotRunning() {
SVAutoLock lock(&running_mutex_);
if (async_running_) return false;
async_running_ = true;
return true;
}
// Releases the running lock.
void LSTMTester::UnlockRunning() {
SVAutoLock lock(&running_mutex_);
async_running_ = false;
}
} // namespace tesseract
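A condensed sketch of the RunEvalAsync protocol above, with std::mutex and std::thread standing in for SVMutex and SVSync::StartThread (class and names illustrative): each call returns the previous evaluation's result, and at most one evaluation thread runs at a time.

#include <mutex>
#include <string>
#include <thread>

class AsyncEval {
 public:
  std::string RunAsync(int iteration) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (running_) return "Previous test incomplete";
      running_ = true;
    }
    std::string prev = result_;  // safe: no worker is running now.
    result_.clear();
    std::thread([this, iteration] {
      result_ = "eval result for iteration " + std::to_string(iteration);
      std::lock_guard<std::mutex> lock(mutex_);
      running_ = false;  // publishes result_ to the next caller.
    }).detach();
    return prev;
  }

 private:
  std::mutex mutex_;
  bool running_ = false;
  std::string result_;
};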
///////////////////////////////////////////////////////////////////////
// File: lstmtester.h
// Description: Top-level line evaluation class for LSTM-based networks.
// Author: Ray Smith
// Created: Wed Nov 23 11:05:06 PST 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_TRAINING_LSTMTESTER_H_
#define TESSERACT_TRAINING_LSTMTESTER_H_
#include "genericvector.h"
#include "lstmtrainer.h"
#include "strngs.h"
#include "svutil.h"
namespace tesseract {
class LSTMTester {
public:
LSTMTester(inT64 max_memory);
// Loads a set of lstmf files, made by running tesseract with the lstm.train
// config, into memory ready for testing. Returns false if nothing was
// loaded. The arg names a file that lists the filenames, with one
// name per line. Conveniently, tesstrain.sh generates such a file, along
// with the files themselves.
bool LoadAllEvalData(const STRING& filenames_file);
// Loads a set of lstmf files, made by running tesseract with the lstm.train
// config, into memory ready for testing. Returns false if nothing was
// loaded.
bool LoadAllEvalData(const GenericVector<STRING>& filenames);
// Runs an evaluation asynchronously on the stored eval data and returns a
// string describing the results of the previous test. Args match TestCallback
// declared in lstmtrainer.h:
// iteration: Current learning iteration number.
// training_errors: If not null, is an array of size ET_COUNT, indexed by
// the ErrorTypes enum and indicates the current errors measured by the
// trainer, and this is a serious request to run an evaluation. If null,
// then the caller is just polling for the results of the previous eval.
// model_data: is the model to evaluate, which should be a serialized
// LSTMTrainer.
// training_stage: an arbitrary number indicating the progress of training.
STRING RunEvalAsync(int iteration, const double* training_errors,
const GenericVector<char>& model_data,
int training_stage);
// Runs an evaluation synchronously on the stored eval data and returns a
// string describing the results. Args as RunEvalAsync.
STRING RunEvalSync(int iteration, const double* training_errors,
const GenericVector<char>& model_data, int training_stage);
private:
// Static helper thread function for RunEvalAsync, with a specific signature
// required by SVSync::StartThread. Actually a member function pretending to
// be static, its arg is a this pointer that it will cast back to LSTMTester*
// to call RunEvalSync using the stored args that RunEvalAsync saves in *this.
// LockIfNotRunning must have returned true before calling ThreadFunc, and
// it will call UnlockRunning to release the lock after RunEvalSync completes.
static void* ThreadFunc(void* lstmtester_void);
// Returns true if there is currently nothing running, and takes the lock
// if there is nothing running.
bool LockIfNotRunning();
// Releases the running lock.
void UnlockRunning();
// The data to test with.
DocumentCache test_data_;
int total_pages_;
// Flag that indicates an asynchronous test is currently running.
// Protected by running_mutex_.
bool async_running_;
SVMutex running_mutex_;
// Stored copies of the args for use while running asynchronously.
int test_iteration_;
const double* test_training_errors_;
GenericVector<char> test_model_data_;
int test_training_stage_;
STRING test_result_;
};
} // namespace tesseract
#endif // TESSERACT_TRAINING_LSTMTESTER_H_
......@@ -20,6 +20,7 @@
#include "base/commandlineflags.h"
#endif
#include "commontraining.h"
#include "lstmtester.h"
#include "lstmtrainer.h"
#include "params.h"
#include "strngs.h"
......@@ -27,8 +28,8 @@
#include "unicharset_training_utils.h"
INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment.");
STRING_PARAM_FLAG(net_spec, "[I1,48Lt1,100O]", "Network specification");
INT_PARAM_FLAG(train_mode, 64, "Controls gross training behavior.");
STRING_PARAM_FLAG(net_spec, "", "Network specification");
INT_PARAM_FLAG(train_mode, 80, "Controls gross training behavior.");
INT_PARAM_FLAG(net_mode, 192, "Controls network behavior.");
INT_PARAM_FLAG(perfect_sample_delay, 4,
"How many imperfect samples between perfect ones.");
......@@ -42,6 +43,10 @@ STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models");
STRING_PARAM_FLAG(script_dir, "",
"Required to set unicharset properties or"
" use unicharset compression.");
STRING_PARAM_FLAG(train_listfile, "",
"File listing training files in lstmf training format.");
STRING_PARAM_FLAG(eval_listfile, "",
"File listing eval files in lstmf training format.");
BOOL_PARAM_FLAG(stop_training, false,
"Just convert the training model to a runtime model.");
INT_PARAM_FLAG(append_index, -1, "Index in continue_from Network at which to"
......@@ -106,9 +111,16 @@ int main(int argc, char **argv) {
}
// Get the list of files to process.
if (FLAGS_train_listfile.empty()) {
tprintf("Must supply a list of training filenames! --train_listfile\n");
return 1;
}
GenericVector<STRING> filenames;
for (int arg = 1; arg < argc; ++arg) {
filenames.push_back(STRING(argv[arg]));
if (!tesseract::LoadFileLinesToStrings(FLAGS_train_listfile.c_str(),
&filenames)) {
tprintf("Failed to load list of training filenames from %s\n",
FLAGS_train_listfile.c_str());
return 1;
}
UNICHARSET unicharset;
......@@ -125,6 +137,7 @@ int main(int argc, char **argv) {
return 1;
}
tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
trainer.InitIterations();
}
if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
// We need a unicharset to start from scratch or append.
......@@ -164,6 +177,18 @@ int main(int argc, char **argv) {
char* best_model_dump = NULL;
size_t best_model_size = 0;
STRING best_model_name;
tesseract::LSTMTester tester(static_cast<inT64>(FLAGS_max_image_MB) *
1048576);
tesseract::TestCallback tester_callback = nullptr;
if (!FLAGS_eval_listfile.empty()) {
if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
tprintf("Failed to load eval data from: %s\n",
FLAGS_eval_listfile.c_str());
return 1;
}
tester_callback =
NewPermanentTessCallback(&tester, &tesseract::LSTMTester::RunEvalAsync);
}
do {
// Train a few.
int iteration = trainer.training_iteration();
......@@ -173,11 +198,12 @@ int main(int argc, char **argv) {
trainer.TrainOnLine(&trainer, false);
}
STRING log_str;
trainer.MaintainCheckpoints(NULL, &log_str);
trainer.MaintainCheckpoints(tester_callback, &log_str);
tprintf("%s\n", log_str.string());
} while (trainer.best_error_rate() > FLAGS_target_error_rate &&
(trainer.training_iteration() < FLAGS_max_iterations ||
FLAGS_max_iterations == 0));
delete tester_callback;
tprintf("Finished! Error rate = %g\n", trainer.best_error_rate());
return 0;
} /* main */
......
......@@ -23,6 +23,7 @@
# --langdata_dir DATADIR # Path to tesseract/training/langdata directory.
# --output_dir OUTPUTDIR # Location of output traineddata file.
# --overwrite # Safe to overwrite files in output_dir.
# --linedata_only # Only generate training data for lstmtraining.
# --run_shape_clustering # Run shape clustering (use for Indic langs).
# --exposures EXPOSURES # A list of exposure levels to use (e.g. "-1 0 1").
#
......@@ -60,13 +61,18 @@ initialize_fontconfig
phase_I_generate_image 8
phase_UP_generate_unicharset
phase_D_generate_dawg
phase_E_extract_features "box.train" 8
phase_C_cluster_prototypes "${TRAINING_DIR}/${LANG_CODE}.normproto"
if [[ "${ENABLE_SHAPE_CLUSTERING}" == "y" ]]; then
phase_S_cluster_shapes
if (( ${LINEDATA} )); then
phase_E_extract_features "lstm.train" 8 "lstmf"
make__lstmdata
else
phase_E_extract_features "box.train" 8 "tr"
phase_C_cluster_prototypes "${TRAINING_DIR}/${LANG_CODE}.normproto"
if [[ "${ENABLE_SHAPE_CLUSTERING}" == "y" ]]; then
phase_S_cluster_shapes
fi
phase_M_cluster_microfeatures
phase_B_generate_ambiguities
make__traineddata
fi
phase_M_cluster_microfeatures
phase_B_generate_ambiguities
make__traineddata
tlog "\nCompleted training for language '${LANG_CODE}'\n"
......@@ -23,6 +23,7 @@ else
fi
OUTPUT_DIR="/tmp/tesstrain/tessdata"
OVERWRITE=0
LINEDATA=0
RUN_SHAPE_CLUSTERING=0
EXTRACT_FONT_PROPERTIES=1
WORKSPACE_DIR=`mktemp -d`
......@@ -90,8 +91,8 @@ parse_flags() {
--)
break;;
--fontlist)
fn=0
FONTS=""
fn=0
FONTS=""
while test $j -lt ${#ARGV[@]}; do
test -z "${ARGV[$j]}" && break
test `echo ${ARGV[$j]} | cut -c -2` = "--" && break
......@@ -124,6 +125,8 @@ parse_flags() {
i=$j ;;
--overwrite)
OVERWRITE=1 ;;
--linedata_only)
LINEDATA=1 ;;
--extract_font_properties)
EXTRACT_FONT_PROPERTIES=1 ;;
--noextract_font_properties)
......@@ -368,10 +371,11 @@ phase_D_generate_dawg() {
phase_E_extract_features() {
local box_config=$1
local par_factor=$2
local ext=$3
if [[ -z ${par_factor} || ${par_factor} -le 0 ]]; then
par_factor=1
fi
tlog "\n=== Phase E: Extracting features ==="
tlog "\n=== Phase E: Generating ${ext} files ==="
local img_files=""
for exposure in ${EXPOSURES}; do
......@@ -401,7 +405,7 @@ phase_E_extract_features() {
export TESSDATA_PREFIX=${OLD_TESSDATA_PREFIX}
# Check that all the output files were produced.
for img_file in ${img_files}; do
check_file_readable ${img_file%.*}.tr
check_file_readable "${img_file%.*}.${ext}"
done
}
......@@ -484,6 +488,39 @@ phase_B_generate_ambiguities() {
# TODO: Add support for generating ambiguities automatically.
}
make__lstmdata() {
tlog "\n=== Constructing LSTM training data ==="
local lang_prefix=${LANGDATA_ROOT}/${LANG_CODE}/${LANG_CODE}
if [[ ! -d ${OUTPUT_DIR} ]]; then
tlog "Creating new directory ${OUTPUT_DIR}"
mkdir -p ${OUTPUT_DIR}
fi
# Copy available files for this language from the langdata dir.
if [[ -r ${lang_prefix}.config ]]; then
tlog "Copying ${lang_prefix}.config to ${OUTPUT_DIR}"
cp ${lang_prefix}.config ${OUTPUT_DIR}
chmod u+w ${OUTPUT_DIR}/${LANG_CODE}.config
fi
if [[ -r "${TRAINING_DIR}/${LANG_CODE}.unicharset" ]]; then
tlog "Moving ${TRAINING_DIR}/${LANG_CODE}.unicharset to ${OUTPUT_DIR}"
mv "${TRAINING_DIR}/${LANG_CODE}.unicharset" "${OUTPUT_DIR}"
fi
for ext in number-dawg punc-dawg word-dawg; do
local src="${TRAINING_DIR}/${LANG_CODE}.${ext}"
if [[ -r "${src}" ]]; then
dest="${OUTPUT_DIR}/${LANG_CODE}.lstm-${ext}"
tlog "Moving ${src} to ${dest}"
mv "${src}" "${dest}"
fi
done
for f in "${TRAINING_DIR}/${LANG_CODE}".*.lstmf; do
tlog "Moving ${f} to ${OUTPUT_DIR}"
mv "${f}" "${OUTPUT_DIR}"
done
local lstm_list="${OUTPUT_DIR}/${LANG_CODE}.training_files.txt"
ls -1 "${OUTPUT_DIR}"/*.lstmf > "${lstm_list}"
}
make__traineddata() {
tlog "\n=== Making final traineddata file ==="
......
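Note the pipeline this creates: tesstrain.sh --linedata_only writes ${LANG_CODE}.training_files.txt via make__lstmdata above, and that list is exactly what lstmtraining now consumes through --train_listfile (and lstmeval through --eval_listfile), as shown in the lstmtraining.cpp hunks above.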
......@@ -26,8 +26,8 @@
#ifdef _WIN32
#ifndef __GNUC__
#include "platform.h"
#include <windows.h>
#include "platform.h"
#if defined(_MSC_VER) && _MSC_VER < 1900
#define snprintf _snprintf
#endif
......