tesseract  4.1.0
lstmtraining.cpp
Go to the documentation of this file.
1 // File: lstmtraining.cpp
3 // Description: Training program for LSTM-based networks.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #ifdef GOOGLE_TESSERACT
19 #include "base/commandlineflags.h"
20 #endif
21 #include <cerrno>
22 #include "commontraining.h"
23 #include "lstmtester.h"
24 #include "lstmtrainer.h"
25 #include "params.h"
26 #include "strngs.h"
27 #include "tprintf.h"
29 
30 static INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment.");
31 static STRING_PARAM_FLAG(net_spec, "", "Network specification");
32 static INT_PARAM_FLAG(net_mode, 192, "Controls network behavior.");
33 static INT_PARAM_FLAG(perfect_sample_delay, 0,
34  "How many imperfect samples between perfect ones.");
35 static DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.");
36 static DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights.");
37 static DOUBLE_PARAM_FLAG(learning_rate, 10.0e-4, "Weight factor for new deltas.");
38 static DOUBLE_PARAM_FLAG(momentum, 0.5, "Decay factor for repeating deltas.");
39 static DOUBLE_PARAM_FLAG(adam_beta, 0.999, "Decay factor for repeating deltas.");
40 static INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images.");
41 static STRING_PARAM_FLAG(continue_from, "", "Existing model to extend");
42 static STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models");
43 static STRING_PARAM_FLAG(train_listfile, "",
44  "File listing training files in lstmf training format.");
45 static STRING_PARAM_FLAG(eval_listfile, "",
46  "File listing eval files in lstmf training format.");
47 static BOOL_PARAM_FLAG(stop_training, false,
48  "Just convert the training model to a runtime model.");
49 static BOOL_PARAM_FLAG(convert_to_int, false,
50  "Convert the recognition model to an integer model.");
51 static BOOL_PARAM_FLAG(sequential_training, false,
52  "Use the training files sequentially instead of round-robin.");
53 static INT_PARAM_FLAG(append_index, -1, "Index in continue_from Network at which to"
54  " attach the new network defined by net_spec");
55 static BOOL_PARAM_FLAG(debug_network, false,
56  "Get info on distribution of weight values");
57 static INT_PARAM_FLAG(max_iterations, 0, "If set, exit after this many iterations");
58 static STRING_PARAM_FLAG(traineddata, "",
59  "Combined Dawgs/Unicharset/Recoder for language model");
60 static STRING_PARAM_FLAG(old_traineddata, "",
61  "When changing the character set, this specifies the old"
62  " character set that is to be replaced");
63 static BOOL_PARAM_FLAG(randomly_rotate, false,
64  "Train OSD and randomly turn training samples upside-down");
65 
// Batch size of the main training loop: main() runs TrainOnLine until
// training_iteration() has advanced by this many, then calls
// MaintainCheckpoints once before starting the next batch.
constexpr int kNumPagesPerBatch = 100;
68 
69 // Apart from command-line flags, input is a collection of lstmf files, that
70 // were previously created using tesseract with the lstm.train config file.
71 // The program iterates over the inputs, feeding the data to the network,
72 // until the error rate reaches a specified target or max_iterations is reached.
73 int main(int argc, char **argv) {
74  tesseract::CheckSharedLibraryVersion();
75  ParseArguments(&argc, &argv);
76  if (FLAGS_model_output.empty()) {
77  tprintf("Must provide a --model_output!\n");
78  return EXIT_FAILURE;
79  }
80  if (FLAGS_traineddata.empty()) {
81  tprintf("Must provide a --traineddata see training wiki\n");
82  return EXIT_FAILURE;
83  }
84 
85  // Check write permissions.
86  STRING test_file = FLAGS_model_output.c_str();
87  test_file += "_wtest";
88  FILE* f = fopen(test_file.c_str(), "wb");
89  if (f != nullptr) {
90  fclose(f);
91  if (remove(test_file.c_str()) != 0) {
92  tprintf("Error, failed to remove %s: %s\n",
93  test_file.c_str(), strerror(errno));
94  return EXIT_FAILURE;
95  }
96  } else {
97  tprintf("Error, model output cannot be written: %s\n", strerror(errno));
98  return EXIT_FAILURE;
99  }
100 
101  // Setup the trainer.
102  STRING checkpoint_file = FLAGS_model_output.c_str();
103  checkpoint_file += "_checkpoint";
104  STRING checkpoint_bak = checkpoint_file + ".bak";
105  tesseract::LSTMTrainer trainer(
106  nullptr, nullptr, nullptr, nullptr, FLAGS_model_output.c_str(),
107  checkpoint_file.c_str(), FLAGS_debug_interval,
108  static_cast<int64_t>(FLAGS_max_image_MB) * 1048576);
109  trainer.InitCharSet(FLAGS_traineddata.c_str());
110 
111  // Reading something from an existing model doesn't require many flags,
112  // so do it now and exit.
113  if (FLAGS_stop_training || FLAGS_debug_network) {
114  if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(), nullptr)) {
115  tprintf("Failed to read continue from: %s\n",
116  FLAGS_continue_from.c_str());
117  return EXIT_FAILURE;
118  }
119  if (FLAGS_debug_network) {
120  trainer.DebugNetwork();
121  } else {
122  if (FLAGS_convert_to_int) trainer.ConvertToInt();
123  if (!trainer.SaveTraineddata(FLAGS_model_output.c_str())) {
124  tprintf("Failed to write recognition model : %s\n",
125  FLAGS_model_output.c_str());
126  }
127  }
128  return EXIT_SUCCESS;
129  }
130 
131  // Get the list of files to process.
132  if (FLAGS_train_listfile.empty()) {
133  tprintf("Must supply a list of training filenames! --train_listfile\n");
134  return EXIT_FAILURE;
135  }
136  GenericVector<STRING> filenames;
137  if (!tesseract::LoadFileLinesToStrings(FLAGS_train_listfile.c_str(),
138  &filenames)) {
139  tprintf("Failed to load list of training filenames from %s\n",
140  FLAGS_train_listfile.c_str());
141  return EXIT_FAILURE;
142  }
143 
144  // Checkpoints always take priority if they are available.
145  if (trainer.TryLoadingCheckpoint(checkpoint_file.string(), nullptr) ||
146  trainer.TryLoadingCheckpoint(checkpoint_bak.string(), nullptr)) {
147  tprintf("Successfully restored trainer from %s\n",
148  checkpoint_file.string());
149  } else {
150  if (!FLAGS_continue_from.empty()) {
151  // Load a past model file to improve upon.
152  if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
153  FLAGS_append_index >= 0
154  ? FLAGS_continue_from.c_str()
155  : FLAGS_old_traineddata.c_str())) {
156  tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str());
157  return EXIT_FAILURE;
158  }
159  tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
160  trainer.InitIterations();
161  }
162  if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
163  if (FLAGS_append_index >= 0) {
164  tprintf("Appending a new network to an old one!!");
165  if (FLAGS_continue_from.empty()) {
166  tprintf("Must set --continue_from for appending!\n");
167  return EXIT_FAILURE;
168  }
169  }
170  // We are initializing from scratch.
171  if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index,
172  FLAGS_net_mode, FLAGS_weight_range,
173  FLAGS_learning_rate, FLAGS_momentum,
174  FLAGS_adam_beta)) {
175  tprintf("Failed to create network from spec: %s\n",
176  FLAGS_net_spec.c_str());
177  return EXIT_FAILURE;
178  }
179  trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
180  }
181  }
182  if (!trainer.LoadAllTrainingData(filenames,
183  FLAGS_sequential_training
186  FLAGS_randomly_rotate)) {
187  tprintf("Load of images failed!!\n");
188  return EXIT_FAILURE;
189  }
190 
191  tesseract::LSTMTester tester(static_cast<int64_t>(FLAGS_max_image_MB) *
192  1048576);
193  tesseract::TestCallback tester_callback = nullptr;
194  if (!FLAGS_eval_listfile.empty()) {
195  if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
196  tprintf("Failed to load eval data from: %s\n",
197  FLAGS_eval_listfile.c_str());
198  return EXIT_FAILURE;
199  }
200  tester_callback =
202  }
203  do {
204  // Train a few.
205  int iteration = trainer.training_iteration();
206  for (int target_iteration = iteration + kNumPagesPerBatch;
207  iteration < target_iteration &&
208  (iteration < FLAGS_max_iterations || FLAGS_max_iterations == 0);
209  iteration = trainer.training_iteration()) {
210  trainer.TrainOnLine(&trainer, false);
211  }
212  STRING log_str;
213  trainer.MaintainCheckpoints(tester_callback, &log_str);
214  tprintf("%s\n", log_str.string());
215  } while (trainer.best_error_rate() > FLAGS_target_error_rate &&
216  (trainer.training_iteration() < FLAGS_max_iterations ||
217  FLAGS_max_iterations == 0));
218  delete tester_callback;
219  tprintf("Finished! Error rate = %g\n", trainer.best_error_rate());
220  return EXIT_SUCCESS;
221 } /* main */
#define STRING_PARAM_FLAG(name, val, comment)
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
bool LoadFileLinesToStrings(const STRING &filename, GenericVector< STRING > *lines)
Definition: strngs.h:45
bool SaveTraineddata(const STRING &filename)
const int kNumPagesPerBatch
#define DOUBLE_PARAM_FLAG(name, val, comment)
void InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:109
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
STRING RunEvalAsync(int iteration, const double *training_errors, const TessdataManager &model_mgr, int training_stage)
Definition: lstmtester.cpp:52
#define BOOL_PARAM_FLAG(name, val, comment)
double best_error_rate() const
Definition: lstmtrainer.h:143
const char * string() const
Definition: strngs.cpp:194
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:36
bool LoadAllTrainingData(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
void ParseArguments(int *argc, char ***argv)
bool LoadAllEvalData(const STRING &filenames_file)
Definition: lstmtester.cpp:30
int training_iteration() const
#define INT_PARAM_FLAG(name, val, comment)
_ConstTessMemberResultCallback_5_0< false, R, T1, P1, P2, P3, P4, P5 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)(P1, P2, P3, P4, P5) const, typename Identity< P1 >::type p1, typename Identity< P2 >::type p2, typename Identity< P3 >::type p3, typename Identity< P4 >::type p4, typename Identity< P5 >::type p5)
Definition: tesscallback.h:258
void set_perfect_delay(int delay)
Definition: lstmtrainer.h:151
int main(int argc, char **argv)
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:259
const char * c_str() const
Definition: strngs.cpp:205