Commit 4e9665de authored by Ray Smith

Added ADAM optimizer, unless git screwed it up, cos there is no diff

Parent 2633fef0
-AM_CPPFLAGS += -I$(top_srcdir)/ccutil -I$(top_srcdir)/viewer
+AM_CPPFLAGS += -I$(top_srcdir)/ccutil -I$(top_srcdir)/viewer -DUSE_STD_NAMESPACE
 AUTOMAKE_OPTIONS = subdir-objects
 SUBDIRS =
 AM_CXXFLAGS =
......
@@ -37,6 +37,9 @@ SIMDDetect SIMDDetect::detector;
 // If true, then AVX has been detected.
 bool SIMDDetect::avx_available_;
+bool SIMDDetect::avx2_available_;
+bool SIMDDetect::avx512F_available_;
+bool SIMDDetect::avx512BW_available_;
 // If true, then SSe4.1 has been detected.
 bool SIMDDetect::sse_available_;
@@ -50,8 +53,19 @@ SIMDDetect::SIMDDetect() {
 #if defined(__GNUC__)
   unsigned int eax, ebx, ecx, edx;
   if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) != 0) {
+    // Note that these tests all use hex because the older compilers don't have
+    // the newer flags.
     sse_available_ = (ecx & 0x00080000) != 0;
     avx_available_ = (ecx & 0x10000000) != 0;
+    if (avx_available_) {
+      // There is supposed to be a __get_cpuid_count function, but this is all
+      // there is in my cpuid.h. It is a macro for an asm statement and cannot
+      // be used inside an if.
+      __cpuid_count(7, 0, eax, ebx, ecx, edx);
+      avx2_available_ = (ebx & 0x00000020) != 0;
+      avx512F_available_ = (ebx & 0x00010000) != 0;
+      avx512BW_available_ = (ebx & 0x40000000) != 0;
+    }
   }
 #elif defined(_WIN32)
   int cpuInfo[4];
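For reference (an aside added here, not text from the commit): the hex masks in this hunk correspond to documented CPUID feature bits and could equivalently be written as shifts. The constant names below are invented for illustration.

// CPUID leaf 1, ECX:            bit 19 = SSE4.1, bit 28 = AVX.
// CPUID leaf 7, subleaf 0, EBX:  bit 5 = AVX2, bit 16 = AVX512F, bit 30 = AVX512BW.
constexpr unsigned kSSE41Mask    = 1u << 19;  // 0x00080000
constexpr unsigned kAVXMask      = 1u << 28;  // 0x10000000
constexpr unsigned kAVX2Mask     = 1u << 5;   // 0x00000020
constexpr unsigned kAVX512FMask  = 1u << 16;  // 0x00010000
constexpr unsigned kAVX512BWMask = 1u << 30;  // 0x40000000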
......
@@ -24,6 +24,16 @@ class SIMDDetect {
  public:
   // Returns true if AVX is available on this system.
   static inline bool IsAVXAvailable() { return detector.avx_available_; }
+  // Returns true if AVX2 (integer support) is available on this system.
+  static inline bool IsAVX2Available() { return detector.avx2_available_; }
+  // Returns true if AVX512 Foundation (float) is available on this system.
+  static inline bool IsAVX512FAvailable() {
+    return detector.avx512F_available_;
+  }
+  // Returns true if AVX512 integer is available on this system.
+  static inline bool IsAVX512BWAvailable() {
+    return detector.avx512BW_available_;
+  }
   // Returns true if SSE4.1 is available on this system.
   static inline bool IsSSEAvailable() { return detector.sse_available_; }
@@ -36,6 +46,9 @@ class SIMDDetect {
   static SIMDDetect detector;
   // If true, then AVX has been detected.
   static TESS_API bool avx_available_;
+  static TESS_API bool avx2_available_;
+  static TESS_API bool avx512F_available_;
+  static TESS_API bool avx512BW_available_;
   // If true, then SSe4.1 has been detected.
   static TESS_API bool sse_available_;
 };
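A minimal sketch of how a caller might report what the new queries return (illustrative only; assumes the header is reachable as "simddetect.h", which is not shown in this diff):

#include <cstdio>
#include "simddetect.h"  // assumed include path

int main() {
  // Log the flags gathered by the static SIMDDetect::detector instance.
  std::printf("SSE4.1=%d AVX=%d AVX2=%d AVX512F=%d AVX512BW=%d\n",
              tesseract::SIMDDetect::IsSSEAvailable(),
              tesseract::SIMDDetect::IsAVXAvailable(),
              tesseract::SIMDDetect::IsAVX2Available(),
              tesseract::SIMDDetect::IsAVX512FAvailable(),
              tesseract::SIMDDetect::IsAVX512BWAvailable());
  return 0;
}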
@@ -360,19 +360,22 @@ class GENERIC_2D_ARRAY {
   }
   // Accumulates the element-wise sums of squares of src into *this.
-  void SumSquares(const GENERIC_2D_ARRAY<T>& src) {
+  void SumSquares(const GENERIC_2D_ARRAY<T>& src, T decay_factor) {
+    T update_factor = 1.0 - decay_factor;
     int size = num_elements();
     for (int i = 0; i < size; ++i) {
-      array_[i] += src.array_[i] * src.array_[i];
+      array_[i] = array_[i] * decay_factor +
+                  update_factor * src.array_[i] * src.array_[i];
     }
   }
-  // Scales each element using the ada-grad algorithm, ie array_[i] by
-  // sqrt(num_samples/max(1,sqsum[i])).
-  void AdaGradScaling(const GENERIC_2D_ARRAY<T>& sqsum, int num_samples) {
+  // Scales each element using the adam algorithm, ie array_[i] by
+  // sqrt(sqsum[i] + epsilon)).
+  void AdamUpdate(const GENERIC_2D_ARRAY<T>& sum,
+                  const GENERIC_2D_ARRAY<T>& sqsum, T epsilon) {
     int size = num_elements();
     for (int i = 0; i < size; ++i) {
-      array_[i] *= sqrt(num_samples / MAX(1.0, sqsum.array_[i]));
+      array_[i] += sum.array_[i] / (sqrt(sqsum.array_[i]) + epsilon);
     }
   }
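Taken together, the two helpers implement the per-element bookkeeping of an Adam-style step. Writing decay_factor as \beta_2, g for the gradient matrix passed as src to SumSquares, \Delta for the smoothed update passed as sum, and v for the running sqsum, the effect is (explanatory summary, not text from the commit):

v_i \leftarrow \beta_2\, v_i + (1 - \beta_2)\, g_i^2              \quad\text{(SumSquares)}
w_i \leftarrow w_i + \frac{\Delta_i}{\sqrt{v_i} + \epsilon}       \quad\text{(AdamUpdate)}

The learning rate and the bias correction are not applied here; the caller (WeightMatrix::Update, further down) folds them into \Delta and \epsilon before calling AdamUpdate.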
......
@@ -112,7 +112,7 @@ bool Convolve::Backward(bool debug, const NetworkIO& fwd_deltas,
       }
     }
   } while (src_index.Increment());
-  back_deltas->CopyWithNormalization(*delta_sum, fwd_deltas);
+  back_deltas->CopyAll(*delta_sum);
   return true;
 }
......
@@ -79,11 +79,24 @@ void FullyConnected::SetEnableTraining(TrainingState state) {
 // scale `range` picked according to the random number generator `randomizer`.
 int FullyConnected::InitWeights(float range, TRand* randomizer) {
   Network::SetRandomizer(randomizer);
-  num_weights_ = weights_.InitWeightsFloat(no_, ni_ + 1, TestFlag(NF_ADA_GRAD),
+  num_weights_ = weights_.InitWeightsFloat(no_, ni_ + 1, TestFlag(NF_ADAM),
                                            range, randomizer);
   return num_weights_;
 }
+
+// Changes the number of outputs to the size of the given code_map, copying
+// the old weight matrix entries for each output from code_map[output] where
+// non-negative, and uses the mean (over all outputs) of the existing weights
+// for all outputs with negative code_map entries. Returns the new number of
+// weights. Only operates on Softmax layers with old_no outputs.
+int FullyConnected::RemapOutputs(int old_no, const std::vector<int>& code_map) {
+  if (type_ == NT_SOFTMAX && no_ == old_no) {
+    num_weights_ = weights_.RemapOutputs(code_map);
+    no_ = code_map.size();
+  }
+  return num_weights_;
+}
 // Converts a float network to an int network.
 void FullyConnected::ConvertToInt() {
   weights_.ConvertToInt();
@@ -240,7 +253,6 @@ bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
   FinishBackward(*errors_t.get());
   if (needs_to_backprop_) {
     back_deltas->ZeroInvalidElements();
-    back_deltas->CopyWithNormalization(*back_deltas, fwd_deltas);
 #if DEBUG_DETAIL > 0
     tprintf("F Backprop:%s\n", name_.string());
     back_deltas->Print(10);
@@ -281,12 +293,11 @@ void FullyConnected::FinishBackward(const TransposedArray& errors_t) {
     weights_.SumOuterTransposed(errors_t, *external_source_, true);
 }
-// Updates the weights using the given learning rate and momentum.
-// num_samples is the quotient to be used in the adagrad computation iff
-// use_ada_grad_ is true.
+// Updates the weights using the given learning rate, momentum and adam_beta.
+// num_samples is used in the adam computation iff use_adam_ is true.
 void FullyConnected::Update(float learning_rate, float momentum,
-                            int num_samples) {
-  weights_.Update(learning_rate, momentum, num_samples);
+                            float adam_beta, int num_samples) {
+  weights_.Update(learning_rate, momentum, adam_beta, num_samples);
 }
 // Sums the products of weight updates in *this and other, splitting into
......
@@ -68,6 +68,12 @@ class FullyConnected : public Network {
   // Sets up the network for training. Initializes weights using weights of
   // scale `range` picked according to the random number generator `randomizer`.
   virtual int InitWeights(float range, TRand* randomizer);
+  // Changes the number of outputs to the size of the given code_map, copying
+  // the old weight matrix entries for each output from code_map[output] where
+  // non-negative, and uses the mean (over all outputs) of the existing weights
+  // for all outputs with negative code_map entries. Returns the new number of
+  // weights. Only operates on Softmax layers with old_no outputs.
+  int RemapOutputs(int old_no, const std::vector<int>& code_map) override;
   // Converts a float network to an int network.
   virtual void ConvertToInt();
@@ -101,10 +107,10 @@ class FullyConnected : public Network {
                      TransposedArray* errors_t, double* backprop);
   void FinishBackward(const TransposedArray& errors_t);
-  // Updates the weights using the given learning rate and momentum.
-  // num_samples is the quotient to be used in the adagrad computation iff
-  // use_ada_grad_ is true.
-  virtual void Update(float learning_rate, float momentum, int num_samples);
+  // Updates the weights using the given learning rate, momentum and adam_beta.
+  // num_samples is used in the adam computation iff use_adam_ is true.
+  void Update(float learning_rate, float momentum, float adam_beta,
+              int num_samples) override;
   // Sums the products of weight updates in *this and other, splitting into
   // positive (same direction) in *same and negative (different direction) in
   // *changed.
......
@@ -132,7 +132,7 @@ int LSTM::InitWeights(float range, TRand* randomizer) {
   for (int w = 0; w < WT_COUNT; ++w) {
     if (w == GFS && !Is2D()) continue;
     num_weights_ += gate_weights_[w].InitWeightsFloat(
-        ns_, na_ + 1, TestFlag(NF_ADA_GRAD), range, randomizer);
+        ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
   }
   if (softmax_ != NULL) {
     num_weights_ += softmax_->InitWeights(range, randomizer);
@@ -140,6 +140,19 @@ int LSTM::InitWeights(float range, TRand* randomizer) {
   return num_weights_;
 }
+
+// Changes the number of outputs to the size of the given code_map, copying
+// the old weight matrix entries for each output from code_map[output] where
+// non-negative, and uses the mean (over all outputs) of the existing weights
+// for all outputs with negative code_map entries. Returns the new number of
+// weights. Only operates on Softmax layers with old_no outputs.
+int LSTM::RemapOutputs(int old_no, const std::vector<int>& code_map) {
+  if (softmax_ != NULL) {
+    num_weights_ -= softmax_->num_weights();
+    num_weights_ += softmax_->RemapOutputs(old_no, code_map);
+  }
+  return num_weights_;
+}
 // Converts a float network to an int network.
 void LSTM::ConvertToInt() {
   for (int w = 0; w < WT_COUNT; ++w) {
@@ -618,27 +631,22 @@ bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
   if (softmax_ != NULL) {
     softmax_->FinishBackward(*softmax_errors_t);
   }
-  if (needs_to_backprop_) {
-    // Normalize the inputerr in back_deltas.
-    back_deltas->CopyWithNormalization(*back_deltas, fwd_deltas);
-    return true;
-  }
-  return false;
+  return needs_to_backprop_;
 }
-// Updates the weights using the given learning rate and momentum.
-// num_samples is the quotient to be used in the adagrad computation iff
-// use_ada_grad_ is true.
-void LSTM::Update(float learning_rate, float momentum, int num_samples) {
+// Updates the weights using the given learning rate, momentum and adam_beta.
+// num_samples is used in the adam computation iff use_adam_ is true.
+void LSTM::Update(float learning_rate, float momentum, float adam_beta,
+                  int num_samples) {
 #if DEBUG_DETAIL > 3
   PrintW();
 #endif
   for (int w = 0; w < WT_COUNT; ++w) {
     if (w == GFS && !Is2D()) continue;
-    gate_weights_[w].Update(learning_rate, momentum, num_samples);
+    gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
   }
   if (softmax_ != NULL) {
-    softmax_->Update(learning_rate, momentum, num_samples);
+    softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
   }
 #if DEBUG_DETAIL > 3
   PrintDW();
......
@@ -76,6 +76,12 @@ class LSTM : public Network {
   // Sets up the network for training. Initializes weights using weights of
   // scale `range` picked according to the random number generator `randomizer`.
   virtual int InitWeights(float range, TRand* randomizer);
+  // Changes the number of outputs to the size of the given code_map, copying
+  // the old weight matrix entries for each output from code_map[output] where
+  // non-negative, and uses the mean (over all outputs) of the existing weights
+  // for all outputs with negative code_map entries. Returns the new number of
+  // weights. Only operates on Softmax layers with old_no outputs.
+  int RemapOutputs(int old_no, const std::vector<int>& code_map) override;
   // Converts a float network to an int network.
   virtual void ConvertToInt();
@@ -99,10 +105,10 @@ class LSTM : public Network {
   virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
                         NetworkScratch* scratch,
                         NetworkIO* back_deltas);
-  // Updates the weights using the given learning rate and momentum.
-  // num_samples is the quotient to be used in the adagrad computation iff
-  // use_ada_grad_ is true.
-  virtual void Update(float learning_rate, float momentum, int num_samples);
+  // Updates the weights using the given learning rate, momentum and adam_beta.
+  // num_samples is used in the adam computation iff use_adam_ is true.
+  void Update(float learning_rate, float momentum, float adam_beta,
+              int num_samples) override;
   // Sums the products of weight updates in *this and other, splitting into
   // positive (same direction) in *same and negative (different direction) in
   // *changed.
......
@@ -55,9 +55,9 @@ LSTMRecognizer::LSTMRecognizer()
       training_iteration_(0),
       sample_iteration_(0),
       null_char_(UNICHAR_BROKEN),
-      weight_range_(0.0f),
       learning_rate_(0.0f),
       momentum_(0.0f),
+      adam_beta_(0.0f),
       dict_(NULL),
       search_(NULL),
       debug_win_(NULL) {}
@@ -94,7 +94,7 @@ bool LSTMRecognizer::Serialize(const TessdataManager* mgr, TFile* fp) const {
   if (fp->FWrite(&sample_iteration_, sizeof(sample_iteration_), 1) != 1)
     return false;
   if (fp->FWrite(&null_char_, sizeof(null_char_), 1) != 1) return false;
-  if (fp->FWrite(&weight_range_, sizeof(weight_range_), 1) != 1) return false;
+  if (fp->FWrite(&adam_beta_, sizeof(adam_beta_), 1) != 1) return false;
   if (fp->FWrite(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false;
   if (fp->FWrite(&momentum_, sizeof(momentum_), 1) != 1) return false;
   if (include_charsets && IsRecoding() && !recoder_.Serialize(fp)) return false;
@@ -120,8 +120,7 @@ bool LSTMRecognizer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
   if (fp->FReadEndian(&sample_iteration_, sizeof(sample_iteration_), 1) != 1)
     return false;
   if (fp->FReadEndian(&null_char_, sizeof(null_char_), 1) != 1) return false;
-  if (fp->FReadEndian(&weight_range_, sizeof(weight_range_), 1) != 1)
-    return false;
+  if (fp->FReadEndian(&adam_beta_, sizeof(adam_beta_), 1) != 1) return false;
   if (fp->FReadEndian(&learning_rate_, sizeof(learning_rate_), 1) != 1)
     return false;
   if (fp->FReadEndian(&momentum_, sizeof(momentum_), 1) != 1) return false;
@@ -207,14 +206,22 @@ void LSTMRecognizer::OutputStats(const NetworkIO& outputs, float* min_output,
   STATS stats(0, kOutputScale + 1);
   for (int t = 0; t < outputs.Width(); ++t) {
     int best_label = outputs.BestLabel(t, NULL);
-    if (best_label != null_char_ || t == 0) {
+    if (best_label != null_char_) {
       float best_output = outputs.f(t)[best_label];
       stats.add(static_cast<int>(kOutputScale * best_output), 1);
     }
   }
-  *min_output = static_cast<float>(stats.min_bucket()) / kOutputScale;
-  *mean_output = stats.mean() / kOutputScale;
-  *sd = stats.sd() / kOutputScale;
+  // If the output is all nulls it could be that the photometric interpretation
+  // is wrong, so make it look bad, so the other way can win, even if not great.
+  if (stats.get_total() == 0) {
+    *min_output = 0.0f;
+    *mean_output = 0.0f;
+    *sd = 1.0f;
+  } else {
+    *min_output = static_cast<float>(stats.min_bucket()) / kOutputScale;
+    *mean_output = stats.mean() / kOutputScale;
+    *sd = stats.sd() / kOutputScale;
+  }
 }
 // Recognizes the image_data, returning the labels,
......
@@ -45,8 +45,6 @@ class ImageData;
 // Enum indicating training mode control flags.
 enum TrainingFlags {
   TF_INT_MODE = 1,
-  TF_AUTO_HARDEN = 2,
-  TF_ROUND_ROBIN_TRAINING = 16,
   TF_COMPRESS_UNICHARSET = 64,
 };
@@ -69,9 +67,6 @@ class LSTMRecognizer {
   double learning_rate() const {
     return learning_rate_;
   }
-  bool IsHardening() const {
-    return (training_flags_ & TF_AUTO_HARDEN) != 0;
-  }
   LossType OutputLossType() const {
     if (network_ == nullptr) return LT_NONE;
     StaticShape shape;
@@ -84,11 +79,6 @@ class LSTMRecognizer {
   bool IsRecoding() const {
     return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
   }
-  // Returns the cache strategy for the DocumentCache.
-  CachingStrategy CacheStrategy() const {
-    return training_flags_ & TF_ROUND_ROBIN_TRAINING ? CS_ROUND_ROBIN
-                                                     : CS_SEQUENTIAL;
-  }
   // Returns true if the network is a TensorFlow network.
   bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }
   // Returns a vector of layer ids that can be passed to other layer functions
@@ -137,10 +127,10 @@ class LSTMRecognizer {
     series->ScaleLayerLearningRate(&id[1], factor);
   }
-  // True if the network is using adagrad to train.
-  bool IsUsingAdaGrad() const { return network_->TestFlag(NF_ADA_GRAD); }
   // Provides access to the UNICHARSET that this classifier works with.
   const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; }
+  // Provides access to the UnicharCompress that this classifier works with.
+  const UnicharCompress& GetRecoder() const { return recoder_; }
   // Provides access to the Dict that this classifier works with.
   const Dict* GetDict() const { return dict_; }
   // Sets the sample iteration to the given value. The sample_iteration_
@@ -215,6 +205,12 @@ class LSTMRecognizer {
                        const GenericVector<int>& label_coords,
                        const char* window_name,
                        ScrollView** window);
+  // Converts the network output to a sequence of labels. Outputs labels, scores
+  // and start xcoords of each char, and each null_char_, with an additional
+  // final xcoord for the end of the output.
+  // The conversion method is determined by internal state.
+  void LabelsFromOutputs(const NetworkIO& outputs, GenericVector<int>* labels,
+                         GenericVector<int>* xcoords);

  protected:
   // Sets the random seed from the sample_iteration_;
@@ -241,12 +237,6 @@ class LSTMRecognizer {
   void DebugActivationRange(const NetworkIO& outputs, const char* label,
                             int best_choice, int x_start, int x_end);
-  // Converts the network output to a sequence of labels. Outputs labels, scores
-  // and start xcoords of each char, and each null_char_, with an additional
-  // final xcoord for the end of the output.
-  // The conversion method is determined by internal state.
-  void LabelsFromOutputs(const NetworkIO& outputs, GenericVector<int>* labels,
-                         GenericVector<int>* xcoords);
   // As LabelsViaCTC except that this function constructs the best path that
   // contains only legal sequences of subcodes for recoder_.
   void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
@@ -290,11 +280,11 @@ class LSTMRecognizer {
   // Index in softmax of null character. May take the value UNICHAR_BROKEN or
   // ccutil_.unicharset.size().
   inT32 null_char_;
-  // Range used for the initial random numbers in the weights.
-  float weight_range_;
   // Learning rate and momentum multipliers of deltas in backprop.
   float learning_rate_;
   float momentum_;
+  // Smoothing factor for 2nd moment of gradients.
+  float adam_beta_;
   // === NOT SERIALIZED.
   TRand randomizer_;
......
@@ -123,11 +123,45 @@ LSTMTrainer::~LSTMTrainer() {
 // Tries to deserialize a trainer from the given file and silently returns
 // false in case of failure.
-bool LSTMTrainer::TryLoadingCheckpoint(const char* filename) {
+bool LSTMTrainer::TryLoadingCheckpoint(const char* filename,
+                                       const char* old_traineddata) {
   GenericVector<char> data;
   if (!(*file_reader_)(filename, &data)) return false;
   tprintf("Loaded file %s, unpacking...\n", filename);
+<<<<<<< Updated upstream
   return checkpoint_reader_->Run(data, this);
+=======
+  if (!checkpoint_reader_->Run(data, this)) return false;
+  StaticShape shape = network_->OutputShape(network_->InputShape());
+  if (((old_traineddata == nullptr || *old_traineddata == '\0') &&
+       network_->NumOutputs() == recoder_.code_range()) ||
+      filename == old_traineddata) {
+    return true;  // Normal checkpoint load complete.
+  }
+  tprintf("Code range changed from %d to %d!!\n", network_->NumOutputs(),
+          recoder_.code_range());
+  if (old_traineddata == nullptr || *old_traineddata == '\0') {
+    tprintf("Must supply the old traineddata for code conversion!\n");
+    return false;
+  }
+  TessdataManager old_mgr;
+  ASSERT_HOST(old_mgr.Init(old_traineddata));
+  TFile fp;
+  if (!old_mgr.GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) return false;
+  UNICHARSET old_chset;
+  if (!old_chset.load_from_file(&fp, false)) return false;
+  if (!old_mgr.GetComponent(TESSDATA_LSTM_RECODER, &fp)) return false;
+  UnicharCompress old_recoder;
+  if (!old_recoder.DeSerialize(&fp)) return false;
+  std::vector<int> code_map = MapRecoder(old_chset, old_recoder);
+  // Set the null_char_ to the new value.
+  int old_null_char = null_char_;
+  SetNullChar();
+  // Map the softmax(s) in the network.
+  network_->RemapOutputs(old_recoder.code_range(), code_map);
+  tprintf("Previous null char=%d mapped to %d\n", old_null_char, null_char_);
+  return true;
+>>>>>>> Stashed changes
 }

 // Initializes the trainer with a network_spec in the network description
@@ -138,11 +172,13 @@ bool LSTMTrainer::TryLoadingCheckpoint(const char* filename) {
 // Note: Be sure to call InitCharSet before InitNetwork!
 bool LSTMTrainer::InitNetwork(const STRING& network_spec, int append_index,
                               int net_flags, float weight_range,
-                              float learning_rate, float momentum) {
+                              float learning_rate, float momentum,
+                              float adam_beta) {
   mgr_.SetVersionString(mgr_.VersionString() + ":" + network_spec.string());
-  weight_range_ = weight_range;
+  adam_beta_ = adam_beta;
   learning_rate_ = learning_rate;
   momentum_ = momentum;
+  SetNullChar();
   if (!NetworkBuilder::InitNetwork(recoder_.code_range(), network_spec,
                                    append_index, net_flags, weight_range,
                                    &randomizer_, &network_)) {
@@ -151,9 +187,10 @@ bool LSTMTrainer::InitNetwork(const STRING& network_spec, int append_index,
   network_str_ += network_spec;
   tprintf("Built network:%s from request %s\n",
           network_->spec().string(), network_spec.string());
-  tprintf("Training parameters:\n Debug interval = %d,"
-          " weights = %g, learning rate = %g, momentum=%g\n",
-          debug_interval_, weight_range_, learning_rate_, momentum_);
+  tprintf(
+      "Training parameters:\n Debug interval = %d,"
+      " weights = %g, learning rate = %g, momentum=%g\n",
+      debug_interval_, weight_range, learning_rate_, momentum_);
   tprintf("null char=%d\n", null_char_);
   return true;
 }
@@ -606,8 +643,6 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
     LR_SAME,   // Learning rate will stay the same.
     LR_COUNT   // Size of arrays.
   };
-  // Epsilon is so small that it may as well be zero, but still positive.
-  const double kEpsilon = 1.0e-30;
   GenericVector<STRING> layers = EnumerateLayers();
   int num_layers = layers.size();
   GenericVector<int> num_weights;
@@ -636,7 +671,7 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
     LSTMTrainer copy_trainer;
     samples_trainer->ReadTrainingDump(orig_trainer, &copy_trainer);
     // Clear the updates, doing nothing else.
    copy_trainer.network_->Update(0.0, 0.0, 0);
+    copy_trainer.network_->Update(0.0, 0.0, 0.0, 0);
     // Adjust the learning rate in each layer.
     for (int i = 0; i < num_layers; ++i) {
       if (num_weights[i] == 0) continue;
@@ -656,9 +691,11 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
       LSTMTrainer layer_trainer;
       samples_trainer->ReadTrainingDump(updated_trainer, &layer_trainer);
       Network* layer = layer_trainer.GetLayer(layers[i]);
-      // Update the weights in just the layer, and also zero the updates
-      // matrix (to epsilon).
-      layer->Update(0.0, kEpsilon, 0);
+      // Update the weights in just the layer, using Adam if enabled.
+      layer->Update(0.0, momentum_, adam_beta_,
+                    layer_trainer.training_iteration_ + 1);
+      // Zero the updates matrix again.
+      layer->Update(0.0, 0.0, 0.0, 0);
       // Train again on the same sample, again holding back the updates.
       layer_trainer.TrainOnLine(trainingdata, true);
       // Count the sign changes in the updates in layer vs in copy_trainer.
@@ -773,7 +810,7 @@ Trainability LSTMTrainer::TrainOnLine(const ImageData* trainingdata,
        training_iteration() >
            last_perfect_training_iteration_ + perfect_delay_)) {
     network_->Backward(debug, targets, &scratch_space_, &bp_deltas);
-    network_->Update(learning_rate_, batch ? -1.0f : momentum_,
+    network_->Update(learning_rate_, batch ? -1.0f : momentum_, adam_beta_,
                      training_iteration_ + 1);
   }
 #ifndef GRAPHICS_DISABLED
@@ -928,6 +965,41 @@ void LSTMTrainer::FillErrorBuffer(double new_error, ErrorTypes type) {
     error_rates_[type] = 100.0 * new_error;
 }
+
+// Helper generates a map from each current recoder_ code (ie softmax index)
+// to the corresponding old_recoder code, or -1 if there isn't one.
+std::vector<int> LSTMTrainer::MapRecoder(
+    const UNICHARSET& old_chset, const UnicharCompress& old_recoder) const {
+  int num_new_codes = recoder_.code_range();
+  int num_new_unichars = GetUnicharset().size();
+  std::vector<int> code_map(num_new_codes, -1);
+  for (int c = 0; c < num_new_codes; ++c) {
+    int old_code = -1;
+    // Find all new unichar_ids that recode to something that includes c.
+    // The <= is to include the null char, which may be beyond the unicharset.
+    for (int uid = 0; uid <= num_new_unichars; ++uid) {
+      RecodedCharID codes;
+      int length = recoder_.EncodeUnichar(uid, &codes);
+      int code_index = 0;
+      while (code_index < length && codes(code_index) != c) ++code_index;
+      if (code_index == length) continue;
+      // The old unicharset must have the same unichar.
+      int old_uid =
+          uid < num_new_unichars
+              ? old_chset.unichar_to_id(GetUnicharset().id_to_unichar(uid))
+              : old_chset.size() - 1;
+      if (old_uid == INVALID_UNICHAR_ID) continue;
+      // The encoding of old_uid at the same code_index is the old code.
+      RecodedCharID old_codes;
+      if (code_index < old_recoder.EncodeUnichar(old_uid, &old_codes)) {
+        old_code = old_codes(code_index);
+        break;
+      }
+    }
+    code_map[c] = old_code;
+  }
+  return code_map;
+}
 // Private version of InitCharSet above finishes the job after initializing
 // the mgr_ data member.
 void LSTMTrainer::InitCharSet() {
@@ -939,6 +1011,11 @@ void LSTMTrainer::InitCharSet() {
               "Must provide a traineddata containing lstm_unicharset and"
               " lstm_recoder!\n" != nullptr);
   }
+  SetNullChar();
+}
+
+// Helper computes and sets the null_char_.
+void LSTMTrainer::SetNullChar() {
   null_char_ = GetUnicharset().has_special_codes() ? UNICHAR_BROKEN
                                                    : GetUnicharset().size();
   RecodedCharID code;
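To make the remapping concrete, here is a small invented example (values are hypothetical, chosen only to illustrate the contract of MapRecoder and the RemapOutputs overrides; it is not taken from the commit):

#include <vector>
// Hypothetical code_map for a new recoder with three codes:
//   code_map[0] == 0   // the same unichar encodes to code 0 in both recoders
//   code_map[1] == 4   // the unichar survives, but its old code was 4
//   code_map[2] == -1  // nothing in the old charset encodes to new code 2
// WeightMatrix::RemapOutputs() then copies softmax weight row 0 from old
// row 0, row 1 from old row 4, and fills row 2 with the mean of all old rows.
std::vector<int> code_map = {0, 4, -1};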
......
@@ -98,8 +98,15 @@ class LSTMTrainer : public LSTMRecognizer {
   virtual ~LSTMTrainer();
   // Tries to deserialize a trainer from the given file and silently returns
+<<<<<<< Updated upstream
   // false in case of failure.
   bool TryLoadingCheckpoint(const char* filename);
+=======
+  // false in case of failure. If old_traineddata is not null, then it is
+  // assumed that the character set is to be re-mapped from old_traininddata to
+  // the new, with consequent change in weight matrices etc.
+  bool TryLoadingCheckpoint(const char* filename, const char* old_traineddata);
+>>>>>>> Stashed changes
   // Initializes the character set encode/decode mechanism directly from a
   // previously setup traineddata containing dawgs, UNICHARSET and
@@ -120,7 +127,8 @@ class LSTMTrainer : public LSTMRecognizer {
   // For other args see NetworkBuilder::InitNetwork.
   // Note: Be sure to call InitCharSet before InitNetwork!
   bool InitNetwork(const STRING& network_spec, int append_index, int net_flags,
-                   float weight_range, float learning_rate, float momentum);
+                   float weight_range, float learning_rate, float momentum,
+                   float adam_beta);
   // Initializes a trainer from a serialized TFNetworkModel proto.
   // Returns the global step of TensorFlow graph or 0 if failed.
   // Building a compatible TF graph: See tfnetwork.proto.
@@ -320,11 +328,17 @@ class LSTMTrainer : public LSTMRecognizer {
   // Fills the whole error buffer of the given type with the given value.
   void FillErrorBuffer(double new_error, ErrorTypes type);
+  // Helper generates a map from each current recoder_ code (ie softmax index)
+  // to the corresponding old_recoder code, or -1 if there isn't one.
+  std::vector<int> MapRecoder(const UNICHARSET& old_chset,
+                              const UnicharCompress& old_recoder) const;

  protected:
   // Private version of InitCharSet above finishes the job after initializing
   // the mgr_ data member.
   void InitCharSet();
+  // Helper computes and sets the null_char_.
+  void SetNullChar();
   // Factored sub-constructor sets up reasonable default values.
   void EmptyConstructor();
......
@@ -85,7 +85,7 @@ enum NetworkType {
 enum NetworkFlags {
   // Network forward/backprop behavior.
   NF_LAYER_SPECIFIC_LR = 64,  // Separate learning rate for each layer.
-  NF_ADA_GRAD = 128,          // Weight-specific learning rate.
+  NF_ADAM = 128,              // Weight-specific learning rate.
 };
 // State of training and desired state used in SetEnableTraining.
@@ -172,6 +172,14 @@ class Network {
   // and should not be deleted by any of the networks.
   // Returns the number of weights initialized.
   virtual int InitWeights(float range, TRand* randomizer);
+  // Changes the number of outputs to the size of the given code_map, copying
+  // the old weight matrix entries for each output from code_map[output] where
+  // non-negative, and uses the mean (over all outputs) of the existing weights
+  // for all outputs with negative code_map entries. Returns the new number of
+  // weights. Only operates on Softmax layers with old_no outputs.
+  virtual int RemapOutputs(int old_no, const std::vector<int>& code_map) {
+    return 0;
+  }
   // Converts a float network to an int network.
   virtual void ConvertToInt() {}
@@ -212,10 +220,10 @@ class Network {
   // Should be overridden by subclasses, but NOT called by their DeSerialize.
   virtual bool DeSerialize(TFile* fp);
-  // Updates the weights using the given learning rate and momentum.
-  // num_samples is the quotient to be used in the adagrad computation iff
-  // use_ada_grad_ is true.
-  virtual void Update(float learning_rate, float momentum, int num_samples) {}
+  // Updates the weights using the given learning rate, momentum and adam_beta.
+  // num_samples is used in the adam computation iff use_adam_ is true.
+  virtual void Update(float learning_rate, float momentum, float adam_beta,
+                      int num_samples) {}
   // Sums the products of weight updates in *this and other, splitting into
   // positive (same direction) in *same and negative (different direction) in
   // *changed.
......
@@ -57,6 +57,19 @@ int Plumbing::InitWeights(float range, TRand* randomizer) {
   return num_weights_;
 }
+
+// Changes the number of outputs to the size of the given code_map, copying
+// the old weight matrix entries for each output from code_map[output] where
+// non-negative, and uses the mean (over all outputs) of the existing weights
+// for all outputs with negative code_map entries. Returns the new number of
+// weights. Only operates on Softmax layers with old_no outputs.
+int Plumbing::RemapOutputs(int old_no, const std::vector<int>& code_map) {
+  num_weights_ = 0;
+  for (int i = 0; i < stack_.size(); ++i) {
+    num_weights_ += stack_[i]->RemapOutputs(old_no, code_map);
+  }
+  return num_weights_;
+}
 // Converts a float network to an int network.
 void Plumbing::ConvertToInt() {
   for (int i = 0; i < stack_.size(); ++i)
@@ -204,10 +217,10 @@ bool Plumbing::DeSerialize(TFile* fp) {
   return true;
 }
-// Updates the weights using the given learning rate and momentum.
-// num_samples is the quotient to be used in the adagrad computation iff
-// use_ada_grad_ is true.
-void Plumbing::Update(float learning_rate, float momentum, int num_samples) {
+// Updates the weights using the given learning rate, momentum and adam_beta.
+// num_samples is used in the adam computation iff use_adam_ is true.
+void Plumbing::Update(float learning_rate, float momentum, float adam_beta,
+                      int num_samples) {
   for (int i = 0; i < stack_.size(); ++i) {
     if (network_flags_ & NF_LAYER_SPECIFIC_LR) {
       if (i < learning_rates_.size())
@@ -216,7 +229,7 @@ void Plumbing::Update(float learning_rate, float momentum, int num_samples) {
         learning_rates_.push_back(learning_rate);
     }
     if (stack_[i]->IsTraining()) {
-      stack_[i]->Update(learning_rate, momentum, num_samples);
+      stack_[i]->Update(learning_rate, momentum, adam_beta, num_samples);
     }
   }
 }
......
@@ -57,6 +57,12 @@ class Plumbing : public Network {
   // and should not be deleted by any of the networks.
   // Returns the number of weights initialized.
   virtual int InitWeights(float range, TRand* randomizer);
+  // Changes the number of outputs to the size of the given code_map, copying
+  // the old weight matrix entries for each output from code_map[output] where
+  // non-negative, and uses the mean (over all outputs) of the existing weights
+  // for all outputs with negative code_map entries. Returns the new number of
+  // weights. Only operates on Softmax layers with old_no outputs.
+  int RemapOutputs(int old_no, const std::vector<int>& code_map) override;
   // Converts a float network to an int network.
   virtual void ConvertToInt();
@@ -118,10 +124,10 @@ class Plumbing : public Network {
   // Reads from the given file. Returns false in case of error.
   virtual bool DeSerialize(TFile* fp);
-  // Updates the weights using the given learning rate and momentum.
-  // num_samples is the quotient to be used in the adagrad computation iff
-  // use_ada_grad_ is true.
-  virtual void Update(float learning_rate, float momentum, int num_samples);
+  // Updates the weights using the given learning rate, momentum and adam_beta.
+  // num_samples is used in the adam computation iff use_adam_ is true.
+  void Update(float learning_rate, float momentum, float adam_beta,
+              int num_samples) override;
   // Sums the products of weight updates in *this and other, splitting into
   // positive (same direction) in *same and negative (different direction) in
   // *changed.
......
@@ -49,7 +49,7 @@ StaticShape Series::OutputShape(const StaticShape& input_shape) const {
 // Note that series has its own implementation just for debug purposes.
 int Series::InitWeights(float range, TRand* randomizer) {
   num_weights_ = 0;
-  tprintf("Num outputs,weights in serial:\n");
+  tprintf("Num outputs,weights in Series:\n");
   for (int i = 0; i < stack_.size(); ++i) {
     int weights = stack_[i]->InitWeights(range, randomizer);
     tprintf("  %s:%d, %d\n",
@@ -60,6 +60,25 @@ int Series::InitWeights(float range, TRand* randomizer) {
   return num_weights_;
 }
+
+// Changes the number of outputs to the size of the given code_map, copying
+// the old weight matrix entries for each output from code_map[output] where
+// non-negative, and uses the mean (over all outputs) of the existing weights
+// for all outputs with negative code_map entries. Returns the new number of
+// weights. Only operates on Softmax layers with old_no outputs.
+int Series::RemapOutputs(int old_no, const std::vector<int>& code_map) {
+  num_weights_ = 0;
+  tprintf("Num (Extended) outputs,weights in Series:\n");
+  for (int i = 0; i < stack_.size(); ++i) {
+    int weights = stack_[i]->RemapOutputs(old_no, code_map);
+    tprintf("  %s:%d, %d\n", stack_[i]->spec().string(),
+            stack_[i]->NumOutputs(), weights);
+    num_weights_ += weights;
+  }
+  tprintf("Total weights = %d\n", num_weights_);
+  no_ = stack_.back()->NumOutputs();
+  return num_weights_;
+}
 // Sets needs_to_backprop_ to needs_backprop and returns true if
 // needs_backprop || any weights in this network so the next layer forward
 // can be told to produce backprop for this layer if needed.
......
@@ -46,6 +46,12 @@ class Series : public Plumbing {
   // scale `range` picked according to the random number generator `randomizer`.
   // Returns the number of weights initialized.
   virtual int InitWeights(float range, TRand* randomizer);
+  // Changes the number of outputs to the size of the given code_map, copying
+  // the old weight matrix entries for each output from code_map[output] where
+  // non-negative, and uses the mean (over all outputs) of the existing weights
+  // for all outputs with negative code_map entries. Returns the new number of
+  // weights. Only operates on Softmax layers with old_no outputs.
+  int RemapOutputs(int old_no, const std::vector<int>& code_map) override;
   // Sets needs_to_backprop_ to needs_backprop and returns true if
   // needs_backprop || any weights in this network so the next layer forward
......
@@ -26,6 +26,11 @@
 namespace tesseract {
+
+// Number of iterations after which the correction effectively becomes unity.
+const int kAdamCorrectionIterations = 200000;
+// Epsilon in Adam to prevent division by zero.
+const double kAdamEpsilon = 1e-8;
 // Copies the whole input transposed, converted to double, into *this.
 void TransposedArray::Transpose(const GENERIC_2D_ARRAY<double>& input) {
   int width = input.dim1();
@@ -36,7 +41,7 @@ void TransposedArray::Transpose(const GENERIC_2D_ARRAY<double>& input) {
 // Sets up the network for training. Initializes weights using weights of
 // scale `range` picked according to the random number generator `randomizer`.
-int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
+int WeightMatrix::InitWeightsFloat(int no, int ni, bool use_adam,
                                    float weight_range, TRand* randomizer) {
   int_mode_ = false;
   wf_.Resize(no, ni, 0.0);
@@ -47,11 +52,37 @@ int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
       }
     }
   }
-  use_ada_grad_ = ada_grad;
+  use_adam_ = use_adam;
   InitBackward();
   return ni * no;
 }
+
+// Changes the number of outputs to the size of the given code_map, copying
+// the old weight matrix entries for each output from code_map[output] where
+// non-negative, and uses the mean (over all outputs) of the existing weights
+// for all outputs with negative code_map entries. Returns the new number of
+// weights.
+int WeightMatrix::RemapOutputs(const std::vector<int>& code_map) {
+  GENERIC_2D_ARRAY<double> old_wf(wf_);
+  int old_no = wf_.dim1();
+  int new_no = code_map.size();
+  int ni = wf_.dim2();
+  std::vector<double> means(ni, 0.0);
+  for (int c = 0; c < old_no; ++c) {
+    const double* weights = wf_[c];
+    for (int i = 0; i < ni; ++i) means[i] += weights[i];
+  }
+  for (double& mean : means) mean /= old_no;
+  wf_.ResizeNoInit(new_no, ni);
+  InitBackward();
+  for (int dest = 0; dest < new_no; ++dest) {
+    int src = code_map[dest];
+    const double* src_data = src >= 0 ? old_wf[src] : means.data();
+    memcpy(wf_[dest], src_data, ni * sizeof(*src_data));
+  }
+  return ni * new_no;
+}
 // Converts a float network to an int network. Each set of input weights that
 // corresponds to a single output weight is converted independently:
 // Compute the max absolute value of the weight set.
@@ -90,13 +121,13 @@ void WeightMatrix::InitBackward() {
   dw_.Resize(no, ni, 0.0);
   updates_.Resize(no, ni, 0.0);
   wf_t_.Transpose(wf_);
-  if (use_ada_grad_) dw_sq_sum_.Resize(no, ni, 0.0);
+  if (use_adam_) dw_sq_sum_.Resize(no, ni, 0.0);
 }
 // Flag on mode to indicate that this weightmatrix uses inT8.
 const int kInt8Flag = 1;
-// Flag on mode to indicate that this weightmatrix uses ada grad.
-const int kAdaGradFlag = 4;
+// Flag on mode to indicate that this weightmatrix uses adam.
+const int kAdamFlag = 4;
 // Flag on mode to indicate that this weightmatrix uses double. Set
 // independently of kInt8Flag as even in int mode the scales can
 // be float or double.
@@ -106,8 +137,8 @@ const int kDoubleFlag = 128;
 bool WeightMatrix::Serialize(bool training, TFile* fp) const {
   // For backward compatibility, add kDoubleFlag to mode to indicate the doubles
   // format, without errs, so we can detect and read old format weight matrices.
-  uinT8 mode = (int_mode_ ? kInt8Flag : 0) |
-               (use_ada_grad_ ? kAdaGradFlag : 0) | kDoubleFlag;
+  uinT8 mode =
+      (int_mode_ ? kInt8Flag : 0) | (use_adam_ ? kAdamFlag : 0) | kDoubleFlag;
   if (fp->FWrite(&mode, sizeof(mode), 1) != 1) return false;
   if (int_mode_) {
     if (!wi_.Serialize(fp)) return false;
@@ -115,7 +146,7 @@ bool WeightMatrix::Serialize(bool training, TFile* fp) const {
   } else {
     if (!wf_.Serialize(fp)) return false;
     if (training && !updates_.Serialize(fp)) return false;
-    if (training && use_ada_grad_ && !dw_sq_sum_.Serialize(fp)) return false;
+    if (training && use_adam_ && !dw_sq_sum_.Serialize(fp)) return false;
   }
   return true;
 }
@@ -126,7 +157,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
   uinT8 mode = 0;
   if (fp->FRead(&mode, sizeof(mode), 1) != 1) return false;
   int_mode_ = (mode & kInt8Flag) != 0;
-  use_ada_grad_ = (mode & kAdaGradFlag) != 0;
+  use_adam_ = (mode & kAdamFlag) != 0;
   if ((mode & kDoubleFlag) == 0) return DeSerializeOld(training, fp);
   if (int_mode_) {
     if (!wi_.DeSerialize(fp)) return false;
@@ -136,7 +167,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
     if (training) {
       InitBackward();
       if (!updates_.DeSerialize(fp)) return false;
-      if (use_ada_grad_ && !dw_sq_sum_.DeSerialize(fp)) return false;
+      if (use_adam_ && !dw_sq_sum_.DeSerialize(fp)) return false;
     }
   }
   return true;
...@@ -247,19 +278,27 @@ void WeightMatrix::SumOuterTransposed(const TransposedArray& u, ...@@ -247,19 +278,27 @@ void WeightMatrix::SumOuterTransposed(const TransposedArray& u,
} }
// Updates the weights using the given learning rate and momentum. // Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the adagrad computation iff // num_samples is the quotient to be used in the adam computation iff
// use_ada_grad_ is true. // use_adam_ is true.
void WeightMatrix::Update(double learning_rate, double momentum, void WeightMatrix::Update(double learning_rate, double momentum,
int num_samples) { double adam_beta, int num_samples) {
ASSERT_HOST(!int_mode_); ASSERT_HOST(!int_mode_);
if (use_ada_grad_ && num_samples > 0) { if (use_adam_ && num_samples > 0 && num_samples < kAdamCorrectionIterations) {
dw_sq_sum_.SumSquares(dw_); learning_rate *= sqrt(1.0 - pow(adam_beta, num_samples));
dw_.AdaGradScaling(dw_sq_sum_, num_samples); learning_rate /= 1.0 - pow(momentum, num_samples);
}
if (use_adam_ && num_samples > 0 && momentum > 0.0) {
dw_sq_sum_.SumSquares(dw_, adam_beta);
dw_ *= learning_rate * (1.0 - momentum);
updates_ *= momentum;
updates_ += dw_;
wf_.AdamUpdate(updates_, dw_sq_sum_, learning_rate * kAdamEpsilon);
} else {
dw_ *= learning_rate;
updates_ += dw_;
if (momentum > 0.0) wf_ += updates_;
if (momentum >= 0.0) updates_ *= momentum;
} }
dw_ *= learning_rate;
updates_ += dw_;
if (momentum > 0.0) wf_ += updates_;
if (momentum >= 0.0) updates_ *= momentum;
wf_t_.Transpose(wf_); wf_t_.Transpose(wf_);
} }
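For illustration, the branch above is the usual Adam step applied per weight, with the first kAdamCorrectionIterations samples also scaling learning_rate by sqrt(1 - adam_beta^t) / (1 - momentum^t). A minimal per-element sketch, assuming updates_ holds the first moment, dw_sq_sum_ holds the decayed sum of squared gradients, and AdamUpdate() divides by sqrt(v) plus the epsilon term it is given (the helper name and the exact body of AdamUpdate() are assumptions, not Tesseract API):
#include <cmath>
// Per-weight sketch of the update implemented in the hunk above.
void AdamStepSketch(double g, double lr, double beta1 /* momentum */,
                    double beta2 /* adam_beta */, double eps,
                    double& m /* updates_ */, double& v /* dw_sq_sum_ */,
                    double& w /* wf_ */) {
  v = beta2 * v + (1.0 - beta2) * g * g;   // dw_sq_sum_.SumSquares(dw_, adam_beta)
  m = beta1 * m + (1.0 - beta1) * lr * g;  // dw_ *= lr*(1-momentum); updates_ = momentum*updates_ + dw_
  w += m / (std::sqrt(v) + lr * eps);      // assumed effect of wf_.AdamUpdate(updates_, dw_sq_sum_, lr*kAdamEpsilon)
}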
......
...@@ -62,14 +62,20 @@ class TransposedArray : public GENERIC_2D_ARRAY<double> { ...@@ -62,14 +62,20 @@ class TransposedArray : public GENERIC_2D_ARRAY<double> {
// backward steps with the matrix and updates to the weights. // backward steps with the matrix and updates to the weights.
class WeightMatrix { class WeightMatrix {
public: public:
WeightMatrix() : int_mode_(false), use_ada_grad_(false) {} WeightMatrix() : int_mode_(false), use_adam_(false) {}
// Sets up the network for training. Initializes weights using weights of // Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`. // scale `range` picked according to the random number generator `randomizer`.
// Note the order is outputs, inputs, as this is the order of indices to // Note the order is outputs, inputs, as this is the order of indices to
// the matrix, so the adjacent elements are multiplied by the input during // the matrix, so the adjacent elements are multiplied by the input during
// a forward operation. // a forward operation.
int InitWeightsFloat(int no, int ni, bool ada_grad, float weight_range, int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range,
TRand* randomizer); TRand* randomizer);
// Changes the number of outputs to the size of the given code_map, copying
// the old weight matrix entries for each output from code_map[output] where
// non-negative, and uses the mean (over all outputs) of the existing weights
// for all outputs with negative code_map entries. Returns the new number of
// weights.
int RemapOutputs(const std::vector<int>& code_map);
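A minimal sketch of the remapping described above, using plain std::vector rows rather than GENERIC_2D_ARRAY and returning the remapped rows instead of the new weight count; the helper name is hypothetical:
#include <vector>
// Rows with a non-negative code_map entry copy that old output row; the rest
// get the mean of all existing rows.
std::vector<std::vector<double>> RemapRows(
    const std::vector<std::vector<double>>& old_rows,
    const std::vector<int>& code_map) {
  const size_t ni = old_rows.empty() ? 0 : old_rows[0].size();
  std::vector<double> mean(ni, 0.0);
  for (const auto& row : old_rows)
    for (size_t i = 0; i < ni; ++i) mean[i] += row[i];
  for (size_t i = 0; i < ni; ++i) mean[i] /= old_rows.size();
  std::vector<std::vector<double>> new_rows(code_map.size());
  for (size_t o = 0; o < code_map.size(); ++o)
    new_rows[o] = code_map[o] >= 0 ? old_rows[code_map[o]] : mean;
  return new_rows;
}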
// Converts a float network to an int network. Each set of input weights that // Converts a float network to an int network. Each set of input weights that
// corresponds to a single output weight is converted independently: // corresponds to a single output weight is converted independently:
...@@ -123,10 +129,10 @@ class WeightMatrix { ...@@ -123,10 +129,10 @@ class WeightMatrix {
// Runs parallel if requested. Note that inputs must be transposed. // Runs parallel if requested. Note that inputs must be transposed.
void SumOuterTransposed(const TransposedArray& u, const TransposedArray& v, void SumOuterTransposed(const TransposedArray& u, const TransposedArray& v,
bool parallel); bool parallel);
// Updates the weights using the given learning rate and momentum. // Updates the weights using the given learning rate, momentum and adam_beta.
// num_samples is the quotient to be used in the adagrad computation iff // num_samples is used in the Adam correction factor.
// use_ada_grad_ is true. void Update(double learning_rate, double momentum, double adam_beta,
void Update(double learning_rate, double momentum, int num_samples); int num_samples);
// Adds the dw_ in other to the dw_ in *this. // Adds the dw_ in other to the dw_ in *this.
void AddDeltas(const WeightMatrix& other); void AddDeltas(const WeightMatrix& other);
// Sums the products of weight updates in *this and other, splitting into // Sums the products of weight updates in *this and other, splitting into
...@@ -163,8 +169,8 @@ class WeightMatrix { ...@@ -163,8 +169,8 @@ class WeightMatrix {
TransposedArray wf_t_; TransposedArray wf_t_;
// Which of wf_ and wi_ are we actually using. // Which of wf_ and wi_ are we actually using.
bool int_mode_; bool int_mode_;
// True if we are running adagrad in this weight matrix. // True if we are running adam in this weight matrix.
bool use_ada_grad_; bool use_adam_;
// If we are using wi_, then scales_ is a factor to restore the row product // If we are using wi_, then scales_ is a factor to restore the row product
// with a vector to the correct range. // with a vector to the correct range.
GenericVector<double> scales_; GenericVector<double> scales_;
...@@ -172,8 +178,8 @@ class WeightMatrix { ...@@ -172,8 +178,8 @@ class WeightMatrix {
// amount to be added to wf_/wi_. // amount to be added to wf_/wi_.
GENERIC_2D_ARRAY<double> dw_; GENERIC_2D_ARRAY<double> dw_;
GENERIC_2D_ARRAY<double> updates_; GENERIC_2D_ARRAY<double> updates_;
// Iff use_ada_grad_, the sum of squares of dw_. The number of samples is // Iff use_adam_, the sum of squares of dw_. The number of samples is
// given to Update(). Serialized iff use_ada_grad_. // given to Update(). Serialized iff use_adam_.
GENERIC_2D_ARRAY<double> dw_sq_sum_; GENERIC_2D_ARRAY<double> dw_sq_sum_;
}; };
......
...@@ -34,8 +34,9 @@ INT_PARAM_FLAG(perfect_sample_delay, 0, ...@@ -34,8 +34,9 @@ INT_PARAM_FLAG(perfect_sample_delay, 0,
"How many imperfect samples between perfect ones."); "How many imperfect samples between perfect ones.");
DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent."); DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.");
DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights."); DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights.");
DOUBLE_PARAM_FLAG(learning_rate, 1.0e-4, "Weight factor for new deltas."); DOUBLE_PARAM_FLAG(learning_rate, 10.0e-4, "Weight factor for new deltas.");
DOUBLE_PARAM_FLAG(momentum, 0.9, "Decay factor for repeating deltas."); DOUBLE_PARAM_FLAG(momentum, 0.5, "Decay factor for repeating deltas.");
DOUBLE_PARAM_FLAG(adam_beta, 0.999, "Decay factor for repeating deltas.");
INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images."); INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images.");
STRING_PARAM_FLAG(continue_from, "", "Existing model to extend"); STRING_PARAM_FLAG(continue_from, "", "Existing model to extend");
STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models"); STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models");
...@@ -56,6 +57,11 @@ BOOL_PARAM_FLAG(debug_network, false, ...@@ -56,6 +57,11 @@ BOOL_PARAM_FLAG(debug_network, false,
INT_PARAM_FLAG(max_iterations, 0, "If set, exit after this many iterations"); INT_PARAM_FLAG(max_iterations, 0, "If set, exit after this many iterations");
STRING_PARAM_FLAG(traineddata, "", STRING_PARAM_FLAG(traineddata, "",
"Combined Dawgs/Unicharset/Recoder for language model"); "Combined Dawgs/Unicharset/Recoder for language model");
STRING_PARAM_FLAG(old_traineddata, "", STRING_PARAM_FLAG(old_traineddata, "",
"Previous traineddata arg when changing the character set"); "Previous traineddata arg when changing the character set");
// Number of training images to train between calls to MaintainCheckpoints. // Number of training images to train between calls to MaintainCheckpoints.
const int kNumPagesPerBatch = 100; const int kNumPagesPerBatch = 100;
...@@ -91,7 +97,7 @@ int main(int argc, char **argv) { ...@@ -91,7 +97,7 @@ int main(int argc, char **argv) {
// Reading something from an existing model doesn't require many flags, // Reading something from an existing model doesn't require many flags,
// so do it now and exit. // so do it now and exit.
if (FLAGS_stop_training || FLAGS_debug_network) { if (FLAGS_stop_training || FLAGS_debug_network) {
if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str())) { if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(), nullptr)) {
tprintf("Failed to read continue from: %s\n", tprintf("Failed to read continue from: %s\n",
FLAGS_continue_from.c_str()); FLAGS_continue_from.c_str());
return 1; return 1;
...@@ -122,14 +128,17 @@ int main(int argc, char **argv) { ...@@ -122,14 +128,17 @@ int main(int argc, char **argv) {
} }
// Checkpoints always take priority if they are available. // Checkpoints always take priority if they are available.
if (trainer.TryLoadingCheckpoint(checkpoint_file.string()) || if (trainer.TryLoadingCheckpoint(checkpoint_file.string(), nullptr) ||
trainer.TryLoadingCheckpoint(checkpoint_bak.string())) { trainer.TryLoadingCheckpoint(checkpoint_bak.string(), nullptr)) {
tprintf("Successfully restored trainer from %s\n", tprintf("Successfully restored trainer from %s\n",
checkpoint_file.string()); checkpoint_file.string());
} else { } else {
if (!FLAGS_continue_from.empty()) { if (!FLAGS_continue_from.empty()) {
// Load a past model file to improve upon. // Load a past model file to improve upon.
if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str())) { if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
FLAGS_append_index >= 0
? FLAGS_continue_from.c_str()
: FLAGS_old_traineddata.c_str())) {
tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str()); tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str());
return 1; return 1;
} }
...@@ -147,7 +156,8 @@ int main(int argc, char **argv) { ...@@ -147,7 +156,8 @@ int main(int argc, char **argv) {
// We are initializing from scratch. // We are initializing from scratch.
if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index, if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index,
FLAGS_net_mode, FLAGS_weight_range, FLAGS_net_mode, FLAGS_weight_range,
FLAGS_learning_rate, FLAGS_momentum)) { FLAGS_learning_rate, FLAGS_momentum,
FLAGS_adam_beta)) {
tprintf("Failed to create network from spec: %s\n", tprintf("Failed to create network from spec: %s\n",
FLAGS_net_spec.c_str()); FLAGS_net_spec.c_str());
return 1; return 1;
......