tesseract  4.1.3
lstmtrainer.cpp
Go to the documentation of this file.
1 // File: lstmtrainer.cpp
3 // Description: Top-level line trainer class for LSTM-based networks.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #define _USE_MATH_DEFINES // needed to get definition of M_SQRT1_2
19 
20 // Include automatically generated configuration file if running autoconf.
21 #ifdef HAVE_CONFIG_H
22 #include "config_auto.h"
23 #endif
24 
25 #include "lstmtrainer.h"
26 #include <string>
27 
28 #include "allheaders.h"
29 #include "boxread.h"
30 #include "ctc.h"
31 #include "imagedata.h"
32 #include "input.h"
33 #include "networkbuilder.h"
34 #include "ratngs.h"
35 #include "recodebeam.h"
36 #ifdef INCLUDE_TENSORFLOW
37 #include "tfnetwork.h"
38 #endif
39 #include "tprintf.h"
40 
41 #include "callcpp.h"
42 
43 namespace tesseract {
44 
45 // Min actual error rate increase to constitute divergence.
46 const double kMinDivergenceRate = 50.0;
47 // Min iterations since last best before acting on a stall.
48 const int kMinStallIterations = 10000;
49 // Fraction of current char error rate that sub_trainer_ has to be ahead
50 // before we declare the sub_trainer_ a success and switch to it.
// MaintainCheckpoints also uses it as the relative margin by which the
// current error must exceed best_error_rate_ before a sub-trainer is started.
51 const double kSubTrainerMarginFraction = 3.0 / 128;
52 // Factor to reduce learning rate on divergence.
// M_SQRT1_2 is 1/sqrt(2), so two successive reductions halve the rate.
// Passed as the factor argument of ReduceLayerLearningRates.
53 const double kLearningRateDecay = M_SQRT1_2;
54 // LR adjustment iterations.
// Passed as the num_samples argument of ReduceLayerLearningRates.
55 const int kNumAdjustmentIterations = 100;
56 // How often to add data to the error_graph_.
57 const int kErrorGraphInterval = 1000;
58 // Number of training images to train between calls to MaintainCheckpoints.
59 const int kNumPagesPerBatch = 100;
60 // Min percent error rate to consider start-up phase over.
61 const int kMinStartedErrorRate = 75;
62 // Error rate at which to transition to stage 1.
63 const double kStageTransitionThreshold = 10.0;
64 // Confidence beyond which the truth is more likely wrong than the recognizer.
65 const double kHighConfidence = 0.9375; // 15/16.
66 // Fraction of weight sign-changing total to constitute a definite improvement.
// In ReduceLayerLearningRates a layer's rate is lowered only when its
// reduced-rate bad fraction is below the same-rate bad fraction times this.
67 const double kImprovementFraction = 15.0 / 16.0;
68 // Fraction of last written best to make it worth writing another.
69 const double kBestCheckpointFraction = 31.0 / 32.0;
70 // Scale factor for display of target activations of CTC.
71 const int kTargetXScale = 5;
72 const int kTargetYScale = 100;
73 
75  : randomly_rotate_(false),
76  training_data_(0),
77  file_reader_(LoadDataFromFile),
78  file_writer_(SaveDataToFile),
79  checkpoint_reader_(
80  NewPermanentTessCallback(this, &LSTMTrainer::ReadTrainingDump)),
81  checkpoint_writer_(
82  NewPermanentTessCallback(this, &LSTMTrainer::SaveTrainingDump)),
83  sub_trainer_(nullptr) {
85  debug_interval_ = 0;
86 }
87 
89  CheckPointReader checkpoint_reader,
90  CheckPointWriter checkpoint_writer,
91  const char* model_base, const char* checkpoint_name,
92  int debug_interval, int64_t max_memory)
93  : randomly_rotate_(false),
94  training_data_(max_memory),
95  file_reader_(file_reader),
96  file_writer_(file_writer),
97  checkpoint_reader_(checkpoint_reader),
98  checkpoint_writer_(checkpoint_writer),
99  sub_trainer_(nullptr),
100  mgr_(file_reader) {
102  if (file_reader_ == nullptr) file_reader_ = LoadDataFromFile;
103  if (file_writer_ == nullptr) file_writer_ = SaveDataToFile;
104  if (checkpoint_reader_ == nullptr) {
107  }
108  if (checkpoint_writer_ == nullptr) {
111  }
112  debug_interval_ = debug_interval;
113  model_base_ = model_base;
114  checkpoint_name_ = checkpoint_name;
115 }
116 
118  delete align_win_;
119  delete target_win_;
120  delete ctc_win_;
121  delete recon_win_;
122  delete checkpoint_reader_;
123  delete checkpoint_writer_;
124  delete sub_trainer_;
125 }
126 
127 // Tries to deserialize a trainer from the given file and silently returns
128 // false in case of failure.
129 bool LSTMTrainer::TryLoadingCheckpoint(const char* filename,
130  const char* old_traineddata) {
131  GenericVector<char> data;
132  if (!(*file_reader_)(filename, &data)) return false;
133  tprintf("Loaded file %s, unpacking...\n", filename);
134  if (!checkpoint_reader_->Run(data, this)) return false;
135  if (IsIntMode()) {
136  tprintf("Error, %s is an integer (fast) model, cannot continue training\n", filename);
137  return false;
138  }
140  if (((old_traineddata == nullptr || *old_traineddata == '\0') &&
142  filename == old_traineddata) {
143  return true; // Normal checkpoint load complete.
144  }
145  tprintf("Code range changed from %d to %d!\n", network_->NumOutputs(),
146  recoder_.code_range());
147  if (old_traineddata == nullptr || *old_traineddata == '\0') {
148  tprintf("Must supply the old traineddata for code conversion!\n");
149  return false;
150  }
151  TessdataManager old_mgr;
152  ASSERT_HOST(old_mgr.Init(old_traineddata));
153  TFile fp;
154  if (!old_mgr.GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) return false;
155  UNICHARSET old_chset;
156  if (!old_chset.load_from_file(&fp, false)) return false;
157  if (!old_mgr.GetComponent(TESSDATA_LSTM_RECODER, &fp)) return false;
158  UnicharCompress old_recoder;
159  if (!old_recoder.DeSerialize(&fp)) return false;
160  std::vector<int> code_map = MapRecoder(old_chset, old_recoder);
161  // Set the null_char_ to the new value.
162  int old_null_char = null_char_;
163  SetNullChar();
164  // Map the softmax(s) in the network.
165  network_->RemapOutputs(old_recoder.code_range(), code_map);
166  tprintf("Previous null char=%d mapped to %d\n", old_null_char, null_char_);
167  return true;
168 }
169 
170 // Initializes the trainer with a network_spec in the network description
171 // net_flags control network behavior according to the NetworkFlags enum.
172 // There isn't really much difference between them - only where the effects
173 // are implemented.
174 // For other args see NetworkBuilder::InitNetwork.
175 // Note: Be sure to call InitCharSet before InitNetwork!
176 bool LSTMTrainer::InitNetwork(const STRING& network_spec, int append_index,
177  int net_flags, float weight_range,
178  float learning_rate, float momentum,
179  float adam_beta) {
180  mgr_.SetVersionString(mgr_.VersionString() + ":" + network_spec.string());
181  adam_beta_ = adam_beta;
183  momentum_ = momentum;
184  SetNullChar();
185  if (!NetworkBuilder::InitNetwork(recoder_.code_range(), network_spec,
186  append_index, net_flags, weight_range,
187  &randomizer_, &network_)) {
188  return false;
189  }
190  network_str_ += network_spec;
191  tprintf("Built network:%s from request %s\n",
192  network_->spec().string(), network_spec.string());
193  tprintf(
194  "Training parameters:\n Debug interval = %d,"
195  " weights = %g, learning rate = %g, momentum=%g\n",
196  debug_interval_, weight_range, learning_rate_, momentum_);
197  tprintf("null char=%d\n", null_char_);
198  return true;
199 }
200 
// Initializes a trainer from a serialized TFNetworkModel proto.
// Any existing network_ is destroyed and replaced by a new TFNetwork that is
// initialized from tf_proto.
// Returns the value produced by TFNetwork::InitFromProtoStr (the global step
// of the TensorFlow graph), which is also stored in training_iteration_, or 0
// if initialization failed. On failure network_ is left as nullptr and the
// partially constructed TFNetwork is released.
#ifdef INCLUDE_TENSORFLOW
int LSTMTrainer::InitTensorFlowNetwork(const std::string& tf_proto) {
  // Null the member immediately so that it never refers to freed memory if
  // we bail out below; otherwise a later delete of network_ would be a
  // double free.
  delete network_;
  network_ = nullptr;
  TFNetwork* tf_net = new TFNetwork("TensorFlow");
  training_iteration_ = tf_net->InitFromProtoStr(tf_proto);
  if (training_iteration_ == 0) {
    tprintf("InitFromProtoStr failed!!\n");
    // Nothing took ownership of tf_net, so free it here.
    delete tf_net;
    return 0;
  }
  // From here on network_ owns the TFNetwork.
  network_ = tf_net;
  // The recoder and the graph must agree on the number of output classes.
  ASSERT_HOST(recoder_.code_range() == tf_net->num_classes());
  return training_iteration_;
}
#endif
217 
218 // Resets all the iteration counters for fine tuning or training a head,
219 // where we want the error reporting to reset.
221  sample_iteration_ = 0;
225  best_error_rate_ = 100.0;
226  best_iteration_ = 0;
227  worst_error_rate_ = 0.0;
228  worst_iteration_ = 0;
231  perfect_delay_ = 0;
233  for (int i = 0; i < ET_COUNT; ++i) {
234  best_error_rates_[i] = 100.0;
235  worst_error_rates_[i] = 0.0;
237  error_rates_[i] = 100.0;
238  }
240 }
241 
242 // If the training sample is usable, grid searches for the optimal
243 // dict_ratio/cert_offset, and returns the results in a string of space-
244 // separated triplets of ratio,offset=worderr.
246  const ImageData* trainingdata, int iteration, double min_dict_ratio,
247  double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
248  double cert_offset_step, double max_cert_offset, STRING* results) {
249  sample_iteration_ = iteration;
250  NetworkIO fwd_outputs, targets;
251  Trainability result =
252  PrepareForBackward(trainingdata, &fwd_outputs, &targets);
253  if (result == UNENCODABLE || result == HI_PRECISION_ERR || dict_ == nullptr)
254  return result;
255 
256  // Encode/decode the truth to get the normalization.
257  GenericVector<int> truth_labels, ocr_labels, xcoords;
258  ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels));
259  // NO-dict error.
260  RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(), nullptr);
261  base_search.Decode(fwd_outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty,
262  nullptr);
263  base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
264  STRING truth_text = DecodeLabels(truth_labels);
265  STRING ocr_text = DecodeLabels(ocr_labels);
266  double baseline_error = ComputeWordError(&truth_text, &ocr_text);
267  results->add_str_double("0,0=", baseline_error);
268 
270  for (double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
271  for (double c = min_cert_offset; c < max_cert_offset;
272  c += cert_offset_step) {
273  search.Decode(fwd_outputs, r, c, RecodeBeamSearch::kMinCertainty, nullptr);
274  search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
275  truth_text = DecodeLabels(truth_labels);
276  ocr_text = DecodeLabels(ocr_labels);
277  // This is destructive on both strings.
278  double word_error = ComputeWordError(&truth_text, &ocr_text);
279  if ((r == min_dict_ratio && c == min_cert_offset) ||
280  !std::isfinite(word_error)) {
281  STRING t = DecodeLabels(truth_labels);
282  STRING o = DecodeLabels(ocr_labels);
283  tprintf("r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
284  t.string(), o.string(), word_error, truth_labels[0]);
285  }
286  results->add_str_double(" ", r);
287  results->add_str_double(",", c);
288  results->add_str_double("=", word_error);
289  }
290  }
291  return result;
292 }
293 
294 // Provides output on the distribution of weight values.
297 }
298 
299 // Loads a set of lstmf files that were created using the lstm.train config to
300 // tesseract into memory ready for training. Returns false if nothing was
301 // loaded.
303  CachingStrategy cache_strategy,
304  bool randomly_rotate) {
305  randomly_rotate_ = randomly_rotate;
307  return training_data_.LoadDocuments(filenames, cache_strategy, file_reader_);
308 }
309 
310 // Keeps track of best and locally worst char error_rate and launches tests
311 // using tester, when a new min or max is reached.
312 // Writes checkpoints at appropriate times and builds and returns a log message
313 // to indicate progress. Returns false if nothing interesting happened.
315  PrepareLogMsg(log_msg);
316  double error_rate = CharError();
317  int iteration = learning_iteration();
318  if (iteration >= stall_iteration_ &&
319  error_rate > best_error_rate_ * (1.0 + kSubTrainerMarginFraction) &&
321  // It hasn't got any better in a long while, and is a margin worse than the
322  // best, so go back to the best model and try a different learning rate.
323  StartSubtrainer(log_msg);
324  }
325  SubTrainerResult sub_trainer_result = STR_NONE;
326  if (sub_trainer_ != nullptr) {
327  sub_trainer_result = UpdateSubtrainer(log_msg);
328  if (sub_trainer_result == STR_REPLACED) {
329  // Reset the inputs, as we have overwritten *this.
330  error_rate = CharError();
331  iteration = learning_iteration();
332  PrepareLogMsg(log_msg);
333  }
334  }
335  bool result = true; // Something interesting happened.
336  GenericVector<char> rec_model_data;
337  if (error_rate < best_error_rate_) {
338  SaveRecognitionDump(&rec_model_data);
339  log_msg->add_str_double(" New best char error = ", error_rate);
340  *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
341  // If sub_trainer_ is not nullptr, either *this beat it to a new best, or it
342  // just overwrote *this. In either case, we have finished with it.
343  delete sub_trainer_;
344  sub_trainer_ = nullptr;
347  log_msg->add_str_int(" Transitioned to stage ", CurrentTrainingStage());
348  }
351  STRING best_model_name = DumpFilename();
352  if (!(*file_writer_)(best_trainer_, best_model_name.c_str())) {
353  *log_msg += " failed to write best model:";
354  } else {
355  *log_msg += " wrote best model:";
357  }
358  *log_msg += best_model_name;
359  }
360  } else if (error_rate > worst_error_rate_) {
361  SaveRecognitionDump(&rec_model_data);
362  log_msg->add_str_double(" New worst char error = ", error_rate);
363  *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
366  // Error rate has ballooned. Go back to the best model.
367  *log_msg += "\nDivergence! ";
368  // Copy best_trainer_ before reading it, as it will get overwritten.
369  GenericVector<char> revert_data(best_trainer_);
370  if (checkpoint_reader_->Run(revert_data, this)) {
371  LogIterations("Reverted to", log_msg);
372  ReduceLearningRates(this, log_msg);
373  } else {
374  LogIterations("Failed to Revert at", log_msg);
375  }
376  // If it fails again, we will wait twice as long before reverting again.
377  stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
378  // Re-save the best trainer with the new learning rates and stall
379  // iteration.
381  }
382  } else {
383  // Something interesting happened only if the sub_trainer_ was trained.
384  result = sub_trainer_result != STR_NONE;
385  }
386  if (checkpoint_writer_ != nullptr && file_writer_ != nullptr &&
387  checkpoint_name_.length() > 0) {
388  // Write a current checkpoint.
389  GenericVector<char> checkpoint;
390  if (!checkpoint_writer_->Run(FULL, this, &checkpoint) ||
391  !(*file_writer_)(checkpoint, checkpoint_name_.c_str())) {
392  *log_msg += " failed to write checkpoint.";
393  } else {
394  *log_msg += " wrote checkpoint.";
395  }
396  }
397  *log_msg += "\n";
398  return result;
399 }
400 
401 // Builds a string containing a progress message with current error rates.
402 void LSTMTrainer::PrepareLogMsg(STRING* log_msg) const {
403  LogIterations("At", log_msg);
404  log_msg->add_str_double(", Mean rms=", error_rates_[ET_RMS]);
405  log_msg->add_str_double("%, delta=", error_rates_[ET_DELTA]);
406  log_msg->add_str_double("%, char train=", error_rates_[ET_CHAR_ERROR]);
407  log_msg->add_str_double("%, word train=", error_rates_[ET_WORD_RECERR]);
408  log_msg->add_str_double("%, skip ratio=", error_rates_[ET_SKIP_RATIO]);
409  *log_msg += "%, ";
410 }
411 
412 // Appends <intro_str> iteration learning_iteration()/training_iteration()/
413 // sample_iteration() to the log_msg.
414 void LSTMTrainer::LogIterations(const char* intro_str, STRING* log_msg) const {
415  *log_msg += intro_str;
416  log_msg->add_str_int(" iteration ", learning_iteration());
417  log_msg->add_str_int("/", training_iteration());
418  log_msg->add_str_int("/", sample_iteration());
419 }
420 
421 // Returns true and increments the training_stage_ if the error rate has just
422 // passed through the given threshold for the first time.
423 bool LSTMTrainer::TransitionTrainingStage(float error_threshold) {
424  if (best_error_rate_ < error_threshold &&
426  ++training_stage_;
427  return true;
428  }
429  return false;
430 }
431 
432 // Writes to the given file. Returns false in case of error.
434  const TessdataManager* mgr, TFile* fp) const {
435  if (!LSTMRecognizer::Serialize(mgr, fp)) return false;
436  if (!fp->Serialize(&learning_iteration_)) return false;
437  if (!fp->Serialize(&prev_sample_iteration_)) return false;
438  if (!fp->Serialize(&perfect_delay_)) return false;
439  if (!fp->Serialize(&last_perfect_training_iteration_)) return false;
440  for (const auto & error_buffer : error_buffers_) {
441  if (!error_buffer.Serialize(fp)) return false;
442  }
443  if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) return false;
444  if (!fp->Serialize(&training_stage_)) return false;
445  uint8_t amount = serialize_amount;
446  if (!fp->Serialize(&amount)) return false;
447  if (serialize_amount == LIGHT) return true; // We are done.
448  if (!fp->Serialize(&best_error_rate_)) return false;
449  if (!fp->Serialize(&best_error_rates_[0], countof(best_error_rates_))) return false;
450  if (!fp->Serialize(&best_iteration_)) return false;
451  if (!fp->Serialize(&worst_error_rate_)) return false;
452  if (!fp->Serialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false;
453  if (!fp->Serialize(&worst_iteration_)) return false;
454  if (!fp->Serialize(&stall_iteration_)) return false;
455  if (!best_model_data_.Serialize(fp)) return false;
456  if (!worst_model_data_.Serialize(fp)) return false;
457  if (serialize_amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp))
458  return false;
459  GenericVector<char> sub_data;
460  if (sub_trainer_ != nullptr && !SaveTrainingDump(LIGHT, sub_trainer_, &sub_data))
461  return false;
462  if (!sub_data.Serialize(fp)) return false;
463  if (!best_error_history_.Serialize(fp)) return false;
464  if (!best_error_iterations_.Serialize(fp)) return false;
465  return fp->Serialize(&improvement_steps_);
466 }
467 
468 // Reads from the given file. Returns false in case of error.
469 // NOTE: It is assumed that the trainer is never read cross-endian.
471  if (!LSTMRecognizer::DeSerialize(mgr, fp)) return false;
472  if (!fp->DeSerialize(&learning_iteration_)) {
473  // Special case. If we successfully decoded the recognizer, but fail here
474  // then it means we were just given a recognizer, so issue a warning and
475  // allow it.
476  tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
479  return true;
480  }
481  if (!fp->DeSerialize(&prev_sample_iteration_)) return false;
482  if (!fp->DeSerialize(&perfect_delay_)) return false;
483  if (!fp->DeSerialize(&last_perfect_training_iteration_)) return false;
484  for (auto & error_buffer : error_buffers_) {
485  if (!error_buffer.DeSerialize(fp)) return false;
486  }
487  if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) return false;
488  if (!fp->DeSerialize(&training_stage_)) return false;
489  uint8_t amount;
490  if (!fp->DeSerialize(&amount)) return false;
491  if (amount == LIGHT) return true; // Don't read the rest.
492  if (!fp->DeSerialize(&best_error_rate_)) return false;
493  if (!fp->DeSerialize(&best_error_rates_[0], countof(best_error_rates_))) return false;
494  if (!fp->DeSerialize(&best_iteration_)) return false;
495  if (!fp->DeSerialize(&worst_error_rate_)) return false;
496  if (!fp->DeSerialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false;
497  if (!fp->DeSerialize(&worst_iteration_)) return false;
498  if (!fp->DeSerialize(&stall_iteration_)) return false;
499  if (!best_model_data_.DeSerialize(fp)) return false;
500  if (!worst_model_data_.DeSerialize(fp)) return false;
501  if (amount != NO_BEST_TRAINER && !best_trainer_.DeSerialize(fp)) return false;
502  GenericVector<char> sub_data;
503  if (!sub_data.DeSerialize(fp)) return false;
504  delete sub_trainer_;
505  if (sub_data.empty()) {
506  sub_trainer_ = nullptr;
507  } else {
508  sub_trainer_ = new LSTMTrainer();
509  if (!ReadTrainingDump(sub_data, sub_trainer_)) return false;
510  }
511  if (!best_error_history_.DeSerialize(fp)) return false;
512  if (!best_error_iterations_.DeSerialize(fp)) return false;
513  return fp->DeSerialize(&improvement_steps_);
514 }
515 
516 // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
517 // learning rates (by scaling reduction, or layer specific, according to
518 // NF_LAYER_SPECIFIC_LR).
520  delete sub_trainer_;
521  sub_trainer_ = new LSTMTrainer();
523  *log_msg += " Failed to revert to previous best for trial!";
524  delete sub_trainer_;
525  sub_trainer_ = nullptr;
526  } else {
527  log_msg->add_str_int(" Trial sub_trainer_ from iteration ",
529  // Reduce learning rate so it doesn't diverge this time.
530  sub_trainer_->ReduceLearningRates(this, log_msg);
531  // If it fails again, we will wait twice as long before reverting again.
532  int stall_offset =
534  stall_iteration_ = learning_iteration() + 2 * stall_offset;
536  // Re-save the best trainer with the new learning rates and stall iteration.
538  }
539 }
540 
541 // While the sub_trainer_ is behind the current training iteration and its
542 // training error is at least kSubTrainerMarginFraction better than the
543 // current training error, trains the sub_trainer_, and returns STR_UPDATED if
544 // it did anything. If it catches up, and has a better error rate than the
545 // current best, as well as a margin over the current error rate, then the
546 // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
547 // returned. STR_NONE is returned if the subtrainer wasn't good enough to
548 // receive any training iterations.
550  double training_error = CharError();
551  double sub_error = sub_trainer_->CharError();
552  double sub_margin = (training_error - sub_error) / sub_error;
553  if (sub_margin >= kSubTrainerMarginFraction) {
554  log_msg->add_str_double(" sub_trainer=", sub_error);
555  log_msg->add_str_double(" margin=", 100.0 * sub_margin);
556  *log_msg += "\n";
557  // Catch up to current iteration.
558  int end_iteration = training_iteration();
559  while (sub_trainer_->training_iteration() < end_iteration &&
560  sub_margin >= kSubTrainerMarginFraction) {
561  int target_iteration =
563  while (sub_trainer_->training_iteration() < target_iteration) {
564  sub_trainer_->TrainOnLine(this, false);
565  }
566  STRING batch_log = "Sub:";
567  sub_trainer_->PrepareLogMsg(&batch_log);
568  batch_log += "\n";
569  tprintf("UpdateSubtrainer:%s", batch_log.string());
570  *log_msg += batch_log;
571  sub_error = sub_trainer_->CharError();
572  sub_margin = (training_error - sub_error) / sub_error;
573  }
574  if (sub_error < best_error_rate_ &&
575  sub_margin >= kSubTrainerMarginFraction) {
576  // The sub_trainer_ has won the race to a new best. Switch to it.
577  GenericVector<char> updated_trainer;
578  SaveTrainingDump(LIGHT, sub_trainer_, &updated_trainer);
579  ReadTrainingDump(updated_trainer, this);
580  log_msg->add_str_int(" Sub trainer wins at iteration ",
582  *log_msg += "\n";
583  return STR_REPLACED;
584  }
585  return STR_UPDATED;
586  }
587  return STR_NONE;
588 }
589 
590 // Reduces network learning rates, either for everything, or for layers
591 // independently, according to NF_LAYER_SPECIFIC_LR.
593  STRING* log_msg) {
595  int num_reduced = ReduceLayerLearningRates(
596  kLearningRateDecay, kNumAdjustmentIterations, samples_trainer);
597  log_msg->add_str_int("\nReduced learning rate on layers: ", num_reduced);
598  } else {
600  log_msg->add_str_double("\nReduced learning rate to :", learning_rate_);
601  }
602  *log_msg += "\n";
603 }
604 
605 // Considers reducing the learning rate independently for each layer down by
606 // factor(<1), or leaving it the same, by double-training the given number of
607 // samples and minimizing the amount of changing of sign of weight updates.
608 // Even if it looks like all weights should remain the same, an adjustment
609 // will be made to guarantee a different result when reverting to an old best.
610 // Returns the number of layer learning rates that were reduced.
611 int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
612  LSTMTrainer* samples_trainer) {
613  enum WhichWay {
614  LR_DOWN, // Learning rate will go down by factor.
615  LR_SAME, // Learning rate will stay the same.
616  LR_COUNT // Size of arrays.
617  };
619  int num_layers = layers.size();
620  GenericVector<int> num_weights;
621  num_weights.init_to_size(num_layers, 0);
622  GenericVector<double> bad_sums[LR_COUNT];
623  GenericVector<double> ok_sums[LR_COUNT];
624  for (int i = 0; i < LR_COUNT; ++i) {
625  bad_sums[i].init_to_size(num_layers, 0.0);
626  ok_sums[i].init_to_size(num_layers, 0.0);
627  }
628  double momentum_factor = 1.0 / (1.0 - momentum_);
629  GenericVector<char> orig_trainer;
630  samples_trainer->SaveTrainingDump(LIGHT, this, &orig_trainer);
631  for (int i = 0; i < num_layers; ++i) {
632  Network* layer = GetLayer(layers[i]);
633  num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
634  }
635  int iteration = sample_iteration();
636  for (int s = 0; s < num_samples; ++s) {
637  // Which way will we modify the learning rate?
638  for (int ww = 0; ww < LR_COUNT; ++ww) {
639  // Transfer momentum to learning rate and adjust by the ww factor.
640  float ww_factor = momentum_factor;
641  if (ww == LR_DOWN) ww_factor *= factor;
642  // Make a copy of *this, so we can mess about without damaging anything.
643  LSTMTrainer copy_trainer;
644  samples_trainer->ReadTrainingDump(orig_trainer, &copy_trainer);
645  // Clear the updates, doing nothing else.
646  copy_trainer.network_->Update(0.0, 0.0, 0.0, 0);
647  // Adjust the learning rate in each layer.
648  for (int i = 0; i < num_layers; ++i) {
649  if (num_weights[i] == 0) continue;
650  copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
651  }
652  copy_trainer.SetIteration(iteration);
653  // Train on the sample, but keep the update in updates_ instead of
654  // applying to the weights.
655  const ImageData* trainingdata =
656  copy_trainer.TrainOnLine(samples_trainer, true);
657  if (trainingdata == nullptr) continue;
658  // We'll now use this trainer again for each layer.
659  GenericVector<char> updated_trainer;
660  samples_trainer->SaveTrainingDump(LIGHT, &copy_trainer, &updated_trainer);
661  for (int i = 0; i < num_layers; ++i) {
662  if (num_weights[i] == 0) continue;
663  LSTMTrainer layer_trainer;
664  samples_trainer->ReadTrainingDump(updated_trainer, &layer_trainer);
665  Network* layer = layer_trainer.GetLayer(layers[i]);
666  // Update the weights in just the layer, using Adam if enabled.
667  layer->Update(0.0, momentum_, adam_beta_,
668  layer_trainer.training_iteration_ + 1);
669  // Zero the updates matrix again.
670  layer->Update(0.0, 0.0, 0.0, 0);
671  // Train again on the same sample, again holding back the updates.
672  layer_trainer.TrainOnLine(trainingdata, true);
673  // Count the sign changes in the updates in layer vs in copy_trainer.
674  float before_bad = bad_sums[ww][i];
675  float before_ok = ok_sums[ww][i];
676  layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
677  &ok_sums[ww][i], &bad_sums[ww][i]);
678  float bad_frac =
679  bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
680  if (bad_frac > 0.0f)
681  bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
682  }
683  }
684  ++iteration;
685  }
686  int num_lowered = 0;
687  for (int i = 0; i < num_layers; ++i) {
688  if (num_weights[i] == 0) continue;
689  Network* layer = GetLayer(layers[i]);
690  float lr = GetLayerLearningRate(layers[i]);
691  double total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
692  double total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
693  double frac_down = bad_sums[LR_DOWN][i] / total_down;
694  double frac_same = bad_sums[LR_SAME][i] / total_same;
695  tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().string(),
696  lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
697  if (frac_down < frac_same * kImprovementFraction) {
698  tprintf(" REDUCED\n");
699  ScaleLayerLearningRate(layers[i], factor);
700  ++num_lowered;
701  } else {
702  tprintf(" SAME\n");
703  }
704  }
705  if (num_lowered == 0) {
706  // Just lower everything to make sure.
707  for (int i = 0; i < num_layers; ++i) {
708  if (num_weights[i] > 0) {
709  ScaleLayerLearningRate(layers[i], factor);
710  ++num_lowered;
711  }
712  }
713  }
714  return num_lowered;
715 }
716 
717 // Converts the string to integer class labels, with appropriate null_char_s
718 // in between if not in SimpleTextOutput mode. Returns false on failure.
719 /* static */
720 bool LSTMTrainer::EncodeString(const STRING& str, const UNICHARSET& unicharset,
721  const UnicharCompress* recoder, bool simple_text,
722  int null_char, GenericVector<int>* labels) {
723  if (str.string() == nullptr || str.length() <= 0) {
724  tprintf("Empty truth string!\n");
725  return false;
726  }
727  int err_index;
728  GenericVector<int> internal_labels;
729  labels->truncate(0);
730  if (!simple_text) labels->push_back(null_char);
731  std::string cleaned = unicharset.CleanupString(str.string());
732  if (unicharset.encode_string(cleaned.c_str(), true, &internal_labels, nullptr,
733  &err_index)) {
734  bool success = true;
735  for (int i = 0; i < internal_labels.size(); ++i) {
736  if (recoder != nullptr) {
737  // Re-encode labels via recoder.
738  RecodedCharID code;
739  int len = recoder->EncodeUnichar(internal_labels[i], &code);
740  if (len > 0) {
741  for (int j = 0; j < len; ++j) {
742  labels->push_back(code(j));
743  if (!simple_text) labels->push_back(null_char);
744  }
745  } else {
746  success = false;
747  err_index = 0;
748  break;
749  }
750  } else {
751  labels->push_back(internal_labels[i]);
752  if (!simple_text) labels->push_back(null_char);
753  }
754  }
755  if (success) return true;
756  }
757  tprintf("Encoding of string failed! Failure bytes:");
758  while (err_index < cleaned.size()) {
759  tprintf(" %x", cleaned[err_index++]);
760  }
761  tprintf("\n");
762  return false;
763 }
764 
765 // Performs forward-backward on the given trainingdata.
766 // Returns a Trainability enum to indicate the suitability of the sample.
768  bool batch) {
769  NetworkIO fwd_outputs, targets;
770  Trainability trainable =
771  PrepareForBackward(trainingdata, &fwd_outputs, &targets);
773  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
774  return trainable; // Sample was unusable.
775  }
776  bool debug = debug_interval_ > 0 &&
778  // Run backprop on the output.
779  NetworkIO bp_deltas;
780  if (network_->IsTraining() &&
781  (trainable != PERFECT ||
784  network_->Backward(debug, targets, &scratch_space_, &bp_deltas);
786  training_iteration_ + 1);
787  }
788 #ifndef GRAPHICS_DISABLED
789  if (debug_interval_ == 1 && debug_win_ != nullptr) {
791  }
792 #endif // GRAPHICS_DISABLED
793  // Roll the memory of past means.
795  return trainable;
796 }
797 
798 // Prepares the ground truth, runs forward, and prepares the targets.
799 // Returns a Trainability enum to indicate the suitability of the sample.
801  NetworkIO* fwd_outputs,
802  NetworkIO* targets) {
803  if (trainingdata == nullptr) {
804  tprintf("Null trainingdata.\n");
805  return UNENCODABLE;
806  }
807  // Ensure repeatability of random elements even across checkpoints.
808  bool debug = debug_interval_ > 0 &&
810  GenericVector<int> truth_labels;
811  if (!EncodeString(trainingdata->transcription(), &truth_labels)) {
812  tprintf("Can't encode transcription: '%s' in language '%s'\n",
813  trainingdata->transcription().string(),
814  trainingdata->language().string());
815  return UNENCODABLE;
816  }
817  bool upside_down = false;
818  if (randomly_rotate_) {
819  // This ensures consistent training results.
820  SetRandomSeed();
821  upside_down = randomizer_.SignedRand(1.0) > 0.0;
822  if (upside_down) {
823  // Modify the truth labels to match the rotation:
824  // Apart from space and null, increment the label. This is changes the
825  // script-id to the same script-id but upside-down.
826  // The labels need to be reversed in order, as the first is now the last.
827  for (int c = 0; c < truth_labels.size(); ++c) {
828  if (truth_labels[c] != UNICHAR_SPACE && truth_labels[c] != null_char_)
829  ++truth_labels[c];
830  }
831  truth_labels.reverse();
832  }
833  }
834  int w = 0;
835  while (w < truth_labels.size() &&
836  (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_))
837  ++w;
838  if (w == truth_labels.size()) {
839  tprintf("Blank transcription: %s\n",
840  trainingdata->transcription().string());
841  return UNENCODABLE;
842  }
843  float image_scale;
844  NetworkIO inputs;
845  bool invert = trainingdata->boxes().empty();
846  if (!RecognizeLine(*trainingdata, invert, debug, invert, upside_down,
847  &image_scale, &inputs, fwd_outputs)) {
848  tprintf("Image not trainable\n");
849  return UNENCODABLE;
850  }
851  targets->Resize(*fwd_outputs, network_->NumOutputs());
852  LossType loss_type = OutputLossType();
853  if (loss_type == LT_SOFTMAX) {
854  if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) {
855  tprintf("Compute simple targets failed!\n");
856  return UNENCODABLE;
857  }
858  } else if (loss_type == LT_CTC) {
859  if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) {
860  tprintf("Compute CTC targets failed!\n");
861  return UNENCODABLE;
862  }
863  } else {
864  tprintf("Logistic outputs not implemented yet!\n");
865  return UNENCODABLE;
866  }
867  GenericVector<int> ocr_labels;
868  GenericVector<int> xcoords;
869  LabelsFromOutputs(*fwd_outputs, &ocr_labels, &xcoords);
870  // CTC does not produce correct target labels to begin with.
871  if (loss_type != LT_CTC) {
872  LabelsFromOutputs(*targets, &truth_labels, &xcoords);
873  }
874  if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels,
875  *targets)) {
876  tprintf("Input width was %d\n", inputs.Width());
877  return UNENCODABLE;
878  }
879  STRING ocr_text = DecodeLabels(ocr_labels);
880  STRING truth_text = DecodeLabels(truth_labels);
881  targets->SubtractAllFromFloat(*fwd_outputs);
882  if (debug_interval_ != 0) {
883  if (truth_text != ocr_text) {
884  tprintf("Iteration %d: BEST OCR TEXT : %s\n",
885  training_iteration(), ocr_text.string());
886  }
887  }
888  double char_error = ComputeCharError(truth_labels, ocr_labels);
889  double word_error = ComputeWordError(&truth_text, &ocr_text);
890  double delta_error = ComputeErrorRates(*targets, char_error, word_error);
891  if (debug_interval_ != 0) {
892  tprintf("File %s line %d %s:\n", trainingdata->imagefilename().string(),
893  trainingdata->page_number(), delta_error == 0.0 ? "(Perfect)" : "");
894  }
895  if (delta_error == 0.0) return PERFECT;
897  return TRAINABLE;
898 }
899 
900 // Writes the trainer to memory, so that the current training state can be
901 // restored. *this must always be the master trainer that retains the only
902 // copy of the training data and language model. trainer is the model that is
903 // actually serialized.
905  const LSTMTrainer* trainer,
906  GenericVector<char>* data) const {
907  TFile fp;
908  fp.OpenWrite(data);
909  return trainer->Serialize(serialize_amount, &mgr_, &fp);
910 }
911 
912 // Restores the model to *this.
914  const char* data, int size) {
915  if (size == 0) {
916  tprintf("Warning: data size is 0 in LSTMTrainer::ReadLocalTrainingDump\n");
917  return false;
918  }
919  TFile fp;
920  fp.Open(data, size);
921  return DeSerialize(mgr, &fp);
922 }
923 
924 // Writes the full recognition traineddata to the given filename.
925 bool LSTMTrainer::SaveTraineddata(const STRING& filename) {
926  GenericVector<char> recognizer_data;
927  SaveRecognitionDump(&recognizer_data);
928  mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
929  recognizer_data.size());
930  return mgr_.SaveFile(filename, file_writer_);
931 }
932 
933 // Writes the recognizer to memory, so that it can be used for testing later.
935  TFile fp;
936  fp.OpenWrite(data);
940 }
941 
942 // Returns a suitable filename for a training dump, based on the model_base_,
943 // the iteration and the error rates.
945  STRING filename;
947  filename.add_str_int("_", best_iteration_);
948  filename += ".checkpoint";
949  return filename;
950 }
951 
952 // Fills the whole error buffer of the given type with the given value.
954  for (int i = 0; i < kRollingBufferSize_; ++i)
955  error_buffers_[type][i] = new_error;
956  error_rates_[type] = 100.0 * new_error;
957 }
958 
959 // Helper generates a map from each current recoder_ code (ie softmax index)
960 // to the corresponding old_recoder code, or -1 if there isn't one.
961 std::vector<int> LSTMTrainer::MapRecoder(
962  const UNICHARSET& old_chset, const UnicharCompress& old_recoder) const {
963  int num_new_codes = recoder_.code_range();
964  int num_new_unichars = GetUnicharset().size();
965  std::vector<int> code_map(num_new_codes, -1);
966  for (int c = 0; c < num_new_codes; ++c) {
967  int old_code = -1;
968  // Find all new unichar_ids that recode to something that includes c.
969  // The <= is to include the null char, which may be beyond the unicharset.
970  for (int uid = 0; uid <= num_new_unichars; ++uid) {
971  RecodedCharID codes;
972  int length = recoder_.EncodeUnichar(uid, &codes);
973  int code_index = 0;
974  while (code_index < length && codes(code_index) != c) ++code_index;
975  if (code_index == length) continue;
976  // The old unicharset must have the same unichar.
977  int old_uid =
978  uid < num_new_unichars
979  ? old_chset.unichar_to_id(GetUnicharset().id_to_unichar(uid))
980  : old_chset.size() - 1;
981  if (old_uid == INVALID_UNICHAR_ID) continue;
982  // The encoding of old_uid at the same code_index is the old code.
983  RecodedCharID old_codes;
984  if (code_index < old_recoder.EncodeUnichar(old_uid, &old_codes)) {
985  old_code = old_codes(code_index);
986  break;
987  }
988  }
989  code_map[c] = old_code;
990  }
991  return code_map;
992 }
993 
994 // Private version of InitCharSet above finishes the job after initializing
995 // the mgr_ data member.
999  // Initialize the unicharset and recoder.
1000  if (!LoadCharsets(&mgr_)) {
1001  ASSERT_HOST(
1002  "Must provide a traineddata containing lstm_unicharset and"
1003  " lstm_recoder!\n" != nullptr);
1004  }
1005  SetNullChar();
1006 }
1007 
1008 // Helper computes and sets the null_char_.
1011  : GetUnicharset().size();
1012  RecodedCharID code;
1014  null_char_ = code(0);
1015 }
1016 
1017 // Factored sub-constructor sets up reasonable default values.
1019  align_win_ = nullptr;
1020  target_win_ = nullptr;
1021  ctc_win_ = nullptr;
1022  recon_win_ = nullptr;
1024  training_stage_ = 0;
1026  InitIterations();
1027 }
1028 
1029 // Outputs the string and periodically displays the given network inputs
1030 // as an image in the given window, and the corresponding labels at the
1031 // corresponding x_starts.
1032 // Returns false if the truth string is empty.
1034  const ImageData& trainingdata,
1035  const NetworkIO& fwd_outputs,
1036  const GenericVector<int>& truth_labels,
1037  const NetworkIO& outputs) {
1038  const STRING& truth_text = DecodeLabels(truth_labels);
1039  if (truth_text.string() == nullptr || truth_text.length() <= 0) {
1040  tprintf("Empty truth string at decode time!\n");
1041  return false;
1042  }
1043  if (debug_interval_ != 0) {
1044  // Get class labels, xcoords and string.
1045  GenericVector<int> labels;
1046  GenericVector<int> xcoords;
1047  LabelsFromOutputs(outputs, &labels, &xcoords);
1048  STRING text = DecodeLabels(labels);
1049  tprintf("Iteration %d: GROUND TRUTH : %s\n",
1050  training_iteration(), truth_text.string());
1051  if (truth_text != text) {
1052  tprintf("Iteration %d: ALIGNED TRUTH : %s\n",
1053  training_iteration(), text.string());
1054  }
1055  if (debug_interval_ > 0 && training_iteration() % debug_interval_ == 0) {
1056  tprintf("TRAINING activation path for truth string %s\n",
1057  truth_text.string());
1058  DebugActivationPath(outputs, labels, xcoords);
1059  DisplayForward(inputs, labels, xcoords, "LSTMTraining", &align_win_);
1060  if (OutputLossType() == LT_CTC) {
1061  DisplayTargets(fwd_outputs, "CTC Outputs", &ctc_win_);
1062  DisplayTargets(outputs, "CTC Targets", &target_win_);
1063  }
1064  }
1065  }
1066  return true;
1067 }
1068 
1069 // Displays the network targets as line a line graph.
1071  const char* window_name, ScrollView** window) {
1072 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics.
1073  int width = targets.Width();
1074  int num_features = targets.NumFeatures();
1075  Network::ClearWindow(true, window_name, width * kTargetXScale, kTargetYScale,
1076  window);
1077  for (int c = 0; c < num_features; ++c) {
1078  int color = c % (ScrollView::GREEN_YELLOW - 1) + 2;
1079  (*window)->Pen(static_cast<ScrollView::Color>(color));
1080  int start_t = -1;
1081  for (int t = 0; t < width; ++t) {
1082  double target = targets.f(t)[c];
1083  target *= kTargetYScale;
1084  if (target >= 1) {
1085  if (start_t < 0) {
1086  (*window)->SetCursor(t - 1, 0);
1087  start_t = t;
1088  }
1089  (*window)->DrawTo(t, target);
1090  } else if (start_t >= 0) {
1091  (*window)->DrawTo(t, 0);
1092  (*window)->DrawTo(start_t - 1, 0);
1093  start_t = -1;
1094  }
1095  }
1096  if (start_t >= 0) {
1097  (*window)->DrawTo(width, 0);
1098  (*window)->DrawTo(start_t - 1, 0);
1099  }
1100  }
1101  (*window)->Update();
1102 #endif // GRAPHICS_DISABLED
1103 }
1104 
1105 // Builds a no-compromises target where the first positions should be the
1106 // truth labels and the rest is padded with the null_char_.
1108  const GenericVector<int>& truth_labels,
1109  NetworkIO* targets) {
1110  if (truth_labels.size() > targets->Width()) {
1111  tprintf("Error: transcription %s too long to fit into target of width %d\n",
1112  DecodeLabels(truth_labels).string(), targets->Width());
1113  return false;
1114  }
1115  for (int i = 0; i < truth_labels.size() && i < targets->Width(); ++i) {
1116  targets->SetActivations(i, truth_labels[i], 1.0);
1117  }
1118  for (int i = truth_labels.size(); i < targets->Width(); ++i) {
1119  targets->SetActivations(i, null_char_, 1.0);
1120  }
1121  return true;
1122 }
1123 
1124 // Builds a target using standard CTC. truth_labels should be pre-padded with
1125 // nulls wherever desired. They don't have to be between all labels.
1126 // outputs is input-output, as it gets clipped to minimum probability.
1128  NetworkIO* outputs, NetworkIO* targets) {
1129  // Bottom-clip outputs to a minimum probability.
1130  CTC::NormalizeProbs(outputs);
1131  return CTC::ComputeCTCTargets(truth_labels, null_char_,
1132  outputs->float_array(), targets);
1133 }
1134 
1135 // Computes network errors, and stores the results in the rolling buffers,
1136 // along with the supplied text_error.
1137 // Returns the delta error of the current sample (not running average.)
1139  double char_error, double word_error) {
1141  // Delta error is the fraction of timesteps with >0.5 error in the top choice
1142  // score. If zero, then the top choice characters are guaranteed correct,
1143  // even when there is residue in the RMS error.
1144  double delta_error = ComputeWinnerError(deltas);
1145  UpdateErrorBuffer(delta_error, ET_DELTA);
1146  UpdateErrorBuffer(word_error, ET_WORD_RECERR);
1147  UpdateErrorBuffer(char_error, ET_CHAR_ERROR);
1148  // Skip ratio measures the difference between sample_iteration_ and
1149  // training_iteration_, which reflects the number of unusable samples,
1150  // usually due to unencodable truth text, or the text not fitting in the
1151  // space for the output.
1152  double skip_count = sample_iteration_ - prev_sample_iteration_;
1153  UpdateErrorBuffer(skip_count, ET_SKIP_RATIO);
1154  return delta_error;
1155 }
1156 
1157 // Computes the network activation RMS error rate.
1159  double total_error = 0.0;
1160  int width = deltas.Width();
1161  int num_classes = deltas.NumFeatures();
1162  for (int t = 0; t < width; ++t) {
1163  const float* class_errs = deltas.f(t);
1164  for (int c = 0; c < num_classes; ++c) {
1165  double error = class_errs[c];
1166  total_error += error * error;
1167  }
1168  }
1169  return sqrt(total_error / (width * num_classes));
1170 }
1171 
1172 // Computes network activation winner error rate. (Number of values that are
1173 // in error by >= 0.5 divided by number of time-steps.) More closely related
1174 // to final character error than RMS, but still directly calculable from
1175 // just the deltas. Because of the binary nature of the targets, zero winner
1176 // error is a sufficient but not necessary condition for zero char error.
1178  int num_errors = 0;
1179  int width = deltas.Width();
1180  int num_classes = deltas.NumFeatures();
1181  for (int t = 0; t < width; ++t) {
1182  const float* class_errs = deltas.f(t);
1183  for (int c = 0; c < num_classes; ++c) {
1184  float abs_delta = fabs(class_errs[c]);
1185  // TODO(rays) Filtering cases where the delta is very large to cut out
1186  // GT errors doesn't work. Find a better way or get better truth.
1187  if (0.5 <= abs_delta)
1188  ++num_errors;
1189  }
1190  }
1191  return static_cast<double>(num_errors) / width;
1192 }
1193 
1194 // Computes a very simple bag of chars char error rate.
1196  const GenericVector<int>& ocr_str) {
1197  GenericVector<int> label_counts;
1198  label_counts.init_to_size(NumOutputs(), 0);
1199  int truth_size = 0;
1200  for (int i = 0; i < truth_str.size(); ++i) {
1201  if (truth_str[i] != null_char_) {
1202  ++label_counts[truth_str[i]];
1203  ++truth_size;
1204  }
1205  }
1206  for (int i = 0; i < ocr_str.size(); ++i) {
1207  if (ocr_str[i] != null_char_) {
1208  --label_counts[ocr_str[i]];
1209  }
1210  }
1211  int char_errors = 0;
1212  for (int i = 0; i < label_counts.size(); ++i) {
1213  char_errors += abs(label_counts[i]);
1214  }
1215  if (truth_size == 0) {
1216  return (char_errors == 0) ? 0.0 : 1.0;
1217  }
1218  return static_cast<double>(char_errors) / truth_size;
1219 }
1220 
1221 // Computes word recall error rate using a very simple bag of words algorithm.
1222 // NOTE that this is destructive on both input strings.
1223 double LSTMTrainer::ComputeWordError(STRING* truth_str, STRING* ocr_str) {
1224  using StrMap = std::unordered_map<std::string, int, std::hash<std::string>>;
1225  GenericVector<STRING> truth_words, ocr_words;
1226  truth_str->split(' ', &truth_words);
1227  if (truth_words.empty()) return 0.0;
1228  ocr_str->split(' ', &ocr_words);
1229  StrMap word_counts;
1230  for (int i = 0; i < truth_words.size(); ++i) {
1231  std::string truth_word(truth_words[i].string());
1232  auto it = word_counts.find(truth_word);
1233  if (it == word_counts.end())
1234  word_counts.insert(std::make_pair(truth_word, 1));
1235  else
1236  ++it->second;
1237  }
1238  for (int i = 0; i < ocr_words.size(); ++i) {
1239  std::string ocr_word(ocr_words[i].string());
1240  auto it = word_counts.find(ocr_word);
1241  if (it == word_counts.end())
1242  word_counts.insert(std::make_pair(ocr_word, -1));
1243  else
1244  --it->second;
1245  }
1246  int word_recall_errs = 0;
1247  for (StrMap::const_iterator it = word_counts.begin(); it != word_counts.end();
1248  ++it) {
1249  if (it->second > 0) word_recall_errs += it->second;
1250  }
1251  return static_cast<double>(word_recall_errs) / truth_words.size();
1252 }
1253 
1254 // Updates the error buffer and corresponding mean of the given type with
1255 // the new_error.
1258  error_buffers_[type][index] = new_error;
1259  // Compute the mean error.
1260  int mean_count = std::min(training_iteration_ + 1, error_buffers_[type].size());
1261  double buffer_sum = 0.0;
1262  for (int i = 0; i < mean_count; ++i) buffer_sum += error_buffers_[type][i];
1263  double mean = buffer_sum / mean_count;
1264  // Trim precision to 1/1000 of 1%.
1265  error_rates_[type] = IntCastRounded(100000.0 * mean) / 1000.0;
1266 }
1267 
1268 // Rolls error buffers and reports the current means.
1271  if (NewSingleError(ET_DELTA) > 0.0)
1273  else
1276  if (debug_interval_ != 0) {
1277  tprintf("Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
1281  }
1282 }
1283 
1284 // Given that error_rate is either a new min or max, updates the best/worst
1285 // error rates, and record of progress.
1286 // Tester is an externally supplied callback function that tests on some
1287 // data set with a given model and records the error rates in a graph.
1288 STRING LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate,
1289  const GenericVector<char>& model_data,
1290  TestCallback tester) {
1291  if (error_rate > best_error_rate_
1292  && iteration < best_iteration_ + kErrorGraphInterval) {
1293  // Too soon to record a new point.
1294  if (tester != nullptr && !worst_model_data_.empty()) {
1297  return tester->Run(worst_iteration_, nullptr, mgr_, CurrentTrainingStage());
1298  } else {
1299  return "";
1300  }
1301  }
1302  STRING result;
1303  // NOTE: there are 2 asymmetries here:
1304  // 1. We are computing the global minimum, but the local maximum in between.
1305  // 2. If the tester returns an empty string, indicating that it is busy,
1306  // call it repeatedly on new local maxima to test the previous min, but
1307  // not the other way around, as there is little point testing the maxima
1308  // between very frequent minima.
1309  if (error_rate < best_error_rate_) {
1310  // This is a new (global) minimum.
1311  if (tester != nullptr && !worst_model_data_.empty()) {
1314  result = tester->Run(worst_iteration_, worst_error_rates_, mgr_,
1317  best_model_data_ = model_data;
1318  }
1319  best_error_rate_ = error_rate;
1320  memcpy(best_error_rates_, error_rates_, sizeof(error_rates_));
1321  best_iteration_ = iteration;
1322  best_error_history_.push_back(error_rate);
1323  best_error_iterations_.push_back(iteration);
1324  // Compute 2% decay time.
1325  double two_percent_more = error_rate + 2.0;
1326  int i;
1327  for (i = best_error_history_.size() - 1;
1328  i >= 0 && best_error_history_[i] < two_percent_more; --i) {
1329  }
1330  int old_iteration = i >= 0 ? best_error_iterations_[i] : 0;
1331  improvement_steps_ = iteration - old_iteration;
1332  tprintf("2 Percent improvement time=%d, best error was %g @ %d\n",
1333  improvement_steps_, i >= 0 ? best_error_history_[i] : 100.0,
1334  old_iteration);
1335  } else if (error_rate > best_error_rate_) {
1336  // This is a new (local) maximum.
1337  if (tester != nullptr) {
1338  if (!best_model_data_.empty()) {
1341  result = tester->Run(best_iteration_, best_error_rates_, mgr_,
1343  } else if (!worst_model_data_.empty()) {
1344  // Allow for multiple data points with "worst" error rate.
1347  result = tester->Run(worst_iteration_, worst_error_rates_, mgr_,
1349  }
1350  if (result.length() > 0)
1352  worst_model_data_ = model_data;
1353  }
1354  }
1355  worst_error_rate_ = error_rate;
1356  memcpy(worst_error_rates_, error_rates_, sizeof(error_rates_));
1357  worst_iteration_ = iteration;
1358  return result;
1359 }
1360 
1361 } // namespace tesseract.
LossType OutputLossType() const
bool LoadDataFromFile(const char *filename, GenericVector< char > *data)
void ScaleLearningRate(double factor)
void UpdateErrorBuffer(double new_error, ErrorTypes type)
bool empty() const
Definition: genericvector.h:91
LIST search(LIST list, void *key, int_compare is_equal)
Definition: oldlist.cpp:258
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:35
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
void split(char c, GenericVector< STRING > *splited)
Definition: strngs.cpp:282
const int kNumAdjustmentIterations
Definition: lstmtrainer.cpp:55
const double kMinDivergenceRate
Definition: lstmtrainer.cpp:46
bool(*)(const GenericVector< char > &, const STRING &) FileWriter
Definition: serialis.h:52
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:291
const STRING & imagefilename() const
Definition: imagedata.h:126
std::string VersionString() const
void OpenWrite(GenericVector< char > *data)
Definition: serialis.cpp:296
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
static void NormalizeProbs(NetworkIO *probs)
Definition: ctc.h:36
int size() const
Definition: unicharset.h:341
void init_to_size(int size, const T &t)
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
virtual STRING spec() const
Definition: network.h:141
const int kErrorGraphInterval
Definition: lstmtrainer.cpp:57
void FillErrorBuffer(double new_error, ErrorTypes type)
void StartSubtrainer(STRING *log_msg)
static constexpr float kMinCertainty
Definition: recodebeam.h:222
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:259
bool GetComponent(TessdataType type, TFile *fp)
bool TransitionTrainingStage(float error_threshold)
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
bool SaveFile(const STRING &filename, FileWriter writer) const
void SetActivations(int t, int label, float ok_score)
Definition: networkio.cpp:537
const int kNumPagesPerBatch
Definition: lstmtrainer.cpp:59
bool SaveTraineddata(const STRING &filename)
int num_weights() const
Definition: network.h:119
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:450
void ScaleLayerLearningRate(const STRING &id, double factor)
const STRING & name() const
Definition: network.h:138
SVEvent * AwaitEvent(SVEventType type)
Definition: scrollview.cpp:443
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
double ComputeRMSError(const NetworkIO &deltas)
const UNICHARSET & GetUnicharset() const
const int kMinStartedErrorRate
Definition: lstmtrainer.cpp:61
ScrollView * ctc_win_
Definition: lstmtrainer.h:401
bool Open(const STRING &filename, FileReader reader)
Definition: serialis.cpp:197
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:424
int Width() const
Definition: networkio.h:107
const char * c_str() const
Definition: strngs.cpp:205
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
void ExtractBestPathAsLabels(GenericVector< int > *labels, GenericVector< int > *xcoords) const
Definition: recodebeam.cpp:133
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:458
bool SaveDataToFile(const GenericVector< char > &data, const STRING &filename)
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
void OverwriteEntry(TessdataType type, const char *data, int size)
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:445
float GetLayerLearningRate(const STRING &id) const
ScrollView * target_win_
Definition: lstmtrainer.h:399
virtual R Run(A1, A2)=0
const int kTargetYScale
Definition: lstmtrainer.cpp:72
const char * string() const
Definition: strngs.cpp:194
bool LoadCharsets(const TessdataManager *mgr)
const double kSubTrainerMarginFraction
Definition: lstmtrainer.cpp:51
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:312
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:457
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:154
virtual void DebugWeights()=0
bool LoadDocuments(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, FileReader reader)
Definition: imagedata.cpp:580
virtual void Update(float learning_rate, float momentum, float adam_beta, int num_samples)
Definition: network.h:230
virtual void CountAlternators(const Network &other, double *same, double *changed) const
Definition: network.h:235
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
void truncate(int size)
constexpr size_t countof(T const (&)[N]) noexcept
Definition: serialis.h:43
double CharError() const
Definition: lstmtrainer.h:139
int32_t length() const
Definition: strngs.cpp:189
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
int NumFeatures() const
Definition: networkio.h:111
bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:133
CachingStrategy
Definition: imagedata.h:42
bool Init(const char *data_file_name)
STRING DecodeLabels(const GenericVector< int > &labels)
bool DeSerialize(bool swap, FILE *fp)
const double kStageTransitionThreshold
Definition: lstmtrainer.cpp:63
bool(*)(const STRING &, GenericVector< char > *) FileReader
Definition: serialis.h:49
bool DeSerialize(char *data, size_t count=1)
Definition: serialis.cpp:104
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:438
static bool InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
virtual R Run(A1, A2, A3, A4)=0
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:246
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
int learning_iteration() const
Definition: lstmtrainer.h:149
UNICHAR_ID unichar_to_id(const char *const unichar_repr) const
Definition: unicharset.cpp:210
void SaveRecognitionDump(GenericVector< char > *data) const
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
static const int kRollingBufferSize_
Definition: lstmtrainer.h:478
bool has_special_codes() const
Definition: unicharset.h:722
Network * GetLayer(const STRING &id) const
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:425
const double kLearningRateDecay
Definition: lstmtrainer.cpp:53
bool Serialize(const TessdataManager *mgr, TFile *fp) const
void add_str_int(const char *str, int number)
Definition: strngs.cpp:377
const GenericVector< TBOX > & boxes() const
Definition: imagedata.h:150
const int kMinStallIterations
Definition: lstmtrainer.cpp:48
TessdataManager mgr_
Definition: lstmtrainer.h:483
void PrepareLogMsg(STRING *log_msg) const
bool Serialize(const char *data, size_t count=1)
Definition: serialis.cpp:148
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
ScrollView * recon_win_
Definition: lstmtrainer.h:403
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:444
double ComputeWinnerError(const NetworkIO &deltas)
void SetVersionString(const std::string &v_str)
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:452
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
double learning_rate() const
Definition: strngs.h:45
const double kImprovementFraction
Definition: lstmtrainer.cpp:67
void LogIterations(const char *intro_str, STRING *log_msg) const
STRING DumpFilename() const
_ConstTessMemberResultCallback_5_0< false, R, T1, P1, P2, P3, P4, P5 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)(P1, P2, P3, P4, P5) const, typename Identity< P1 >::type p1, typename Identity< P2 >::type p2, typename Identity< P3 >::type p3, typename Identity< P4 >::type p4, typename Identity< P5 >::type p5)
Definition: tesscallback.h:258
const double kBestCheckpointFraction
Definition: lstmtrainer.cpp:69
double SignedRand(double range)
Definition: helpers.h:55
float * f(int t)
Definition: networkio.h:115
const int kTargetXScale
Definition: lstmtrainer.cpp:71
GenericVector< STRING > EnumerateLayers() const
int page_number() const
Definition: imagedata.h:132
const STRING & transcription() const
Definition: imagedata.h:147
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
DocumentCache training_data_
Definition: lstmtrainer.h:414
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:76
const double kHighConfidence
Definition: lstmtrainer.cpp:65
virtual StaticShape InputShape() const
Definition: network.h:127
const STRING & language() const
Definition: imagedata.h:141
int CurrentTrainingStage() const
Definition: lstmtrainer.h:211
int push_back(T object)
static std::string CleanupString(const char *utf8_str)
Definition: unicharset.h:246
bool LoadAllTrainingData(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:447
int size() const
Definition: genericvector.h:72
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:110
int NumOutputs() const
Definition: network.h:123
virtual R Run(A1, A2, A3)=0
bool encode_string(const char *str, bool give_up_on_failure, GenericVector< UNICHAR_ID > *encoding, GenericVector< char > *lengths, int *encoded_length) const
Definition: unicharset.cpp:259
void SetIteration(int iteration)
ScrollView * align_win_
Definition: lstmtrainer.h:397
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
void SubtractAllFromFloat(const NetworkIO &src)
Definition: networkio.cpp:824
NetworkScratch scratch_space_
bool IsTraining() const
Definition: network.h:115
bool Serialize(FILE *fp) const
const GENERIC_2D_ARRAY< float > & float_array() const
Definition: networkio.h:139
void add_str_double(const char *str, double number)
Definition: strngs.cpp:387
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45
#define ASSERT_HOST(x)
Definition: errcode.h:88
int InitTensorFlowNetwork(const std::string &tf_proto)
int IntCastRounded(double x)
Definition: helpers.h:175
virtual int RemapOutputs(int old_no, const std::vector< int > &code_map)
Definition: network.h:186
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:388
static bool ComputeCTCTargets(const GenericVector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
Definition: ctc.cpp:54
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:479
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)=0
bool AnySuspiciousTruth(float confidence_thr) const
Definition: networkio.cpp:579