tesseract  4.1.0
lstmrecognizer.h
Go to the documentation of this file.
1 // File: lstmrecognizer.h
3 // Description: Top-level line recognizer class for LSTM-based networks.
4 // Author: Ray Smith
5 // Created: Thu May 02 08:57:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
20 #define TESSERACT_LSTM_LSTMRECOGNIZER_H_
21 
22 #include "ccutil.h"
23 #include "helpers.h"
24 #include "imagedata.h"
25 #include "matrix.h"
26 #include "network.h"
27 #include "networkscratch.h"
28 #include "params.h"
29 #include "recodebeam.h"
30 #include "series.h"
31 #include "strngs.h"
32 #include "unicharcompress.h"
33 
34 class BLOB_CHOICE_IT;
35 struct Pix;
36 class ROW_RES;
37 class ScrollView;
38 class TBOX;
39 class WERD_RES;
40 
41 namespace tesseract {
42 
43 class Dict;
44 class ImageData;
45 
46 // Enum indicating training mode control flags.
50 };
51 
52 // Top-level line recognizer class for LSTM-based networks.
53 // Note that a sub-class, LSTMTrainer is used for training.
55  public:
58 
59  int NumOutputs() const { return network_->NumOutputs(); }  // Forwards to the network. NOTE(review): network_ is dereferenced without a null check.
60  int training_iteration() const { return training_iteration_; }
61  int sample_iteration() const { return sample_iteration_; }
62  double learning_rate() const { return learning_rate_; }
64  if (network_ == nullptr) return LT_NONE;
65  StaticShape shape;
66  shape = network_->OutputShape(shape);
67  return shape.loss_type();
68  }
69  bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; }
70  bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; }
71  // True if recoder_ is active to re-encode text to a smaller space.
72  bool IsRecoding() const {
74  }
75  // Returns true if the network is a TensorFlow network.
76  bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }  // NOTE(review): network_ is dereferenced without a null check.
77  // Returns a vector of layer ids that can be passed to other layer functions
78  // to access a specific layer.
80  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
81  auto* series = static_cast<Series*>(network_);
82  GenericVector<STRING> layers;
83  series->EnumerateLayers(nullptr, &layers);
84  return layers;
85  }
86  // Returns a specific layer from its id (from EnumerateLayers).
87  Network* GetLayer(const STRING& id) const {
88  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);  // Requires a non-null top-level network of type NT_SERIES.
89  ASSERT_HOST(id.length() > 1 && id[0] == ':');  // id must be ':' followed by at least one more character.
90  auto* series = static_cast<Series*>(network_);  // The cast relies on the NT_SERIES assertion above.
91  return series->GetLayer(&id[1]);  // Skips the leading ':' and hands the remainder of the id to Series::GetLayer.
92  }
93  // Returns the learning rate of the layer from its id.
94  float GetLayerLearningRate(const STRING& id) const {
95  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
97  ASSERT_HOST(id.length() > 1 && id[0] == ':');
98  auto* series = static_cast<Series*>(network_);
99  return series->LayerLearningRate(&id[1]);
100  } else {
101  return learning_rate_;
102  }
103  }
104  // Multiplies all the learning rate(s) by the given factor.
105  void ScaleLearningRate(double factor) {
106  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
107  learning_rate_ *= factor;
110  for (int i = 0; i < layers.size(); ++i) {
111  ScaleLayerLearningRate(layers[i], factor);
112  }
113  }
114  }
115  // Multiplies the learning rate of the layer with id, by the given factor.
116  void ScaleLayerLearningRate(const STRING& id, double factor) {
117  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);  // Requires a non-null top-level network of type NT_SERIES.
118  ASSERT_HOST(id.length() > 1 && id[0] == ':');  // id must be ':' followed by at least one more character.
119  auto* series = static_cast<Series*>(network_);  // The cast relies on the NT_SERIES assertion above.
120  series->ScaleLayerLearningRate(&id[1], factor);  // Skips the leading ':' and delegates the scaling to the Series.
121  }
122 
123  // Converts the network to int if not already.
124  void ConvertToInt() {
125  if ((training_flags_ & TF_INT_MODE) == 0) {
128  }
129  }
130 
131  // Provides access to the UNICHARSET that this classifier works with.
132  const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; }
133  // Provides access to the UnicharCompress that this classifier works with.
134  const UnicharCompress& GetRecoder() const { return recoder_; }
135  // Provides access to the Dict that this classifier works with.
136  const Dict* GetDict() const { return dict_; }
137  // Sets the sample iteration to the given value. The sample_iteration_
138  // determines the seed for the random number generator. The training
139  // iteration is incremented only by a successful training iteration.
140  void SetIteration(int iteration) { sample_iteration_ = iteration; }
141  // Accessors for textline image normalization.
142  int NumInputs() const { return network_->NumInputs(); }  // Forwards to the network. NOTE(review): network_ is dereferenced without a null check.
143  int null_char() const { return null_char_; }
144 
145  // Loads a model from mgr, including the dictionary only if lang is not null.
146  bool Load(const ParamsVectors* params, const char* lang,
147  TessdataManager* mgr);
148 
149  // Writes to the given file. Returns false in case of error.
150  // If mgr contains a unicharset and recoder, then they are not encoded to fp.
151  bool Serialize(const TessdataManager* mgr, TFile* fp) const;
152  // Reads from the given file. Returns false in case of error.
153  // If mgr contains a unicharset and recoder, then they are taken from there,
154  // otherwise, they are part of the serialization in fp.
155  bool DeSerialize(const TessdataManager* mgr, TFile* fp);
156  // Loads the charsets from mgr.
157  bool LoadCharsets(const TessdataManager* mgr);
158  // Loads the Recoder.
159  bool LoadRecoder(TFile* fp);
160  // Loads the dictionary if possible from the traineddata file.
161  // Prints a warning message, and returns false but otherwise fails silently
162  // and continues to work without it if loading fails.
163  // Note that dictionary load is independent from DeSerialize, but dependent
164  // on the unicharset matching. This enables training to deserialize a model
165  // from checkpoint or restore without having to go back and reload the
166  // dictionary.
167  bool LoadDictionary(const ParamsVectors* params, const char* lang,
168  TessdataManager* mgr);
169 
170  // Recognizes the line image, contained within image_data, returning the
171  // recognized tesseract WERD_RES for the words.
172  // If invert, tries inverted as well if the normal interpretation doesn't
173  // produce a good enough result. The line_box is used for computing the
174  // box_word in the output words. worst_dict_cert is the worst certainty that
175  // will be used in a dictionary word.
176  void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
177  double worst_dict_cert, const TBOX& line_box,
178  PointerVector<WERD_RES>* words, int lstm_choice_mode = 0);
179 
180  // Helper computes min and mean best results in the output.
181  void OutputStats(const NetworkIO& outputs, float* min_output,
182  float* mean_output, float* sd);
183  // Recognizes the image_data, returning the labels,
184  // scores, and corresponding pairs of start, end x-coords in coords.
185  // Returned in scale_factor is the reduction factor
186  // between the image and the output coords, for computing bounding boxes.
187  // If re_invert is true, the input is inverted back to its original
188  // photometric interpretation if inversion is attempted but fails to
189  // improve the results. This ensures that outputs contains the correct
190  // forward outputs for the best photometric interpretation.
191  // inputs is filled with the used inputs to the network.
192  bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
193  bool re_invert, bool upside_down, float* scale_factor,
194  NetworkIO* inputs, NetworkIO* outputs);
195 
196  // Converts an array of labels to utf-8, whether or not the labels are
197  // augmented with character boundaries.
198  STRING DecodeLabels(const GenericVector<int>& labels);
199 
200  // Displays the forward results in a window with the characters and
201  // boundaries as determined by the labels and label_coords.
202  void DisplayForward(const NetworkIO& inputs, const GenericVector<int>& labels,
203  const GenericVector<int>& label_coords,
204  const char* window_name, ScrollView** window);
205  // Converts the network output to a sequence of labels. Outputs labels, scores
206  // and start xcoords of each char, and each null_char_, with an additional
207  // final xcoord for the end of the output.
208  // The conversion method is determined by internal state.
209  void LabelsFromOutputs(const NetworkIO& outputs, GenericVector<int>* labels,
210  GenericVector<int>* xcoords);
211 
212  protected:
213  // Sets the random seed from sample_iteration_.
214  void SetRandomSeed() {
215  int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001;
216  randomizer_.set_seed(seed);
218  }
219 
220  // Displays the labels and cuts at the corresponding xcoords.
221  // Size of labels should match xcoords.
222  void DisplayLSTMOutput(const GenericVector<int>& labels,
223  const GenericVector<int>& xcoords, int height,
224  ScrollView* window);
225 
226  // Prints debug output detailing the activation path that is implied by the
227  // xcoords.
228  void DebugActivationPath(const NetworkIO& outputs,
229  const GenericVector<int>& labels,
230  const GenericVector<int>& xcoords);
231 
232  // Prints debug output detailing activations and 2nd choice over a range
233  // of positions.
234  void DebugActivationRange(const NetworkIO& outputs, const char* label,
235  int best_choice, int x_start, int x_end);
236 
237  // As LabelsViaCTC except that this function constructs the best path that
238  // contains only legal sequences of subcodes for recoder_.
239  void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
240  GenericVector<int>* xcoords);
241  // Converts the network output to a sequence of labels, with scores, using
242  // the simple character model (each position is a char, and the null_char_ is
243  // mainly intended for tail padding.)
244  void LabelsViaSimpleText(const NetworkIO& output, GenericVector<int>* labels,
245  GenericVector<int>* xcoords);
246 
247  // Returns a string corresponding to the label starting at start. Sets *end
248  // to the next start and if non-null, *decoded to the unichar id.
249  const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
250  int* decoded);
251 
252  // Returns a string corresponding to a given single label id, falling back to
253  // a default of ".." for part of a multi-label unichar-id.
254  const char* DecodeSingleLabel(int label);
255 
256  protected:
257  // The network hierarchy.
259  // The unicharset. Only the unicharset element is serialized.
260  // Has to be a CCUtil, so Dict can point to it.
262  // For backward compatibility, recoder_ is serialized iff
263  // training_flags_ & TF_COMPRESS_UNICHARSET.
264  // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
266 
267  // ==Training parameters that are serialized to provide a record of them.==
269  // Flags used to determine the training method of the network.
270  // See enum TrainingFlags above.
272  // Number of actual backward training steps used.
274  // Index into training sample set. sample_iteration_ >= training_iteration_.
276  // Index in softmax of null character. May take the value UNICHAR_BROKEN or
277  // ccutil_.unicharset.size().
278  int32_t null_char_;
279  // Learning rate and momentum multipliers of deltas in backprop.
281  float momentum_;
282  // Smoothing factor for 2nd moment of gradients.
283  float adam_beta_;
284 
285  // === NOT SERIALIZED.
288  // Language model (optional) to use with the beam search.
290  // Beam search held between uses to optimize memory allocation/use.
292 
293  // == Debugging parameters.==
294  // Recognition debug display window.
296 };
297 
298 } // namespace tesseract.
299 
300 #endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_
NetworkScratch scratch_space_
GenericVector< STRING > EnumerateLayers() const
const Dict * GetDict() const
Definition: rect.h:34
bool TestFlag(NetworkFlags flag) const
Definition: network.h:143
void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
double learning_rate() const
const char * DecodeSingleLabel(int label)
bool Serialize(const TessdataManager *mgr, TFile *fp) const
Definition: strngs.h:45
bool LoadCharsets(const TessdataManager *mgr)
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
virtual void ConvertToInt()
Definition: network.h:190
void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
const UNICHARSET & GetUnicharset() const
void LabelsViaReEncode(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
UNICHARSET unicharset
Definition: ccutil.h:71
float LayerLearningRate(const char *id) const
Definition: plumbing.h:105
void SetIteration(int iteration)
STRING DecodeLabels(const GenericVector< int > &labels)
void ScaleLayerLearningRate(const STRING &id, double factor)
float GetLayerLearningRate(const STRING &id) const
RecodeBeamSearch * search_
int32_t IntRand()
Definition: helpers.h:50
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
bool SimpleTextOutput() const
bool Load(const ParamsVectors *params, const char *lang, TessdataManager *mgr)
void LabelsViaSimpleText(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
const char * DecodeLabel(const GenericVector< int > &labels, int start, int *end, int *decoded)
Network * GetLayer(const char *id) const
Definition: plumbing.cpp:155
LossType OutputLossType() const
int NumInputs() const
Definition: network.h:119
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
LossType loss_type() const
Definition: static_shape.h:50
void DisplayLSTMOutput(const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window)
#define ASSERT_HOST(x)
Definition: errcode.h:88
void set_seed(uint64_t seed)
Definition: helpers.h:40
void ScaleLearningRate(double factor)
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:132
int training_iteration() const
bool LoadDictionary(const ParamsVectors *params, const char *lang, TessdataManager *mgr)
NetworkType type() const
Definition: network.h:111
const UnicharCompress & GetRecoder() const
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
int size() const
Definition: genericvector.h:70
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
int NumOutputs() const
Definition: network.h:122
Network * GetLayer(const STRING &id) const
void ScaleLayerLearningRate(const char *id, double factor)
Definition: plumbing.h:111