tesseract  4.1.0
tesseract::LSTMRecognizer Class Reference

#include <lstmrecognizer.h>

Inheritance diagram for tesseract::LSTMRecognizer:
tesseract::LSTMTrainer

Public Member Functions

 LSTMRecognizer ()
 
 ~LSTMRecognizer ()
 
int NumOutputs () const
 
int training_iteration () const
 
int sample_iteration () const
 
double learning_rate () const
 
LossType OutputLossType () const
 
bool SimpleTextOutput () const
 
bool IsIntMode () const
 
bool IsRecoding () const
 
bool IsTensorFlow () const
 
GenericVector< STRING > EnumerateLayers () const
 
Network * GetLayer (const STRING &id) const
 
float GetLayerLearningRate (const STRING &id) const
 
void ScaleLearningRate (double factor)
 
void ScaleLayerLearningRate (const STRING &id, double factor)
 
void ConvertToInt ()
 
const UNICHARSET & GetUnicharset () const
 
const UnicharCompress & GetRecoder () const
 
const Dict * GetDict () const
 
void SetIteration (int iteration)
 
int NumInputs () const
 
int null_char () const
 
bool Load (const ParamsVectors *params, const char *lang, TessdataManager *mgr)
 
bool Serialize (const TessdataManager *mgr, TFile *fp) const
 
bool DeSerialize (const TessdataManager *mgr, TFile *fp)
 
bool LoadCharsets (const TessdataManager *mgr)
 
bool LoadRecoder (TFile *fp)
 
bool LoadDictionary (const ParamsVectors *params, const char *lang, TessdataManager *mgr)
 
void RecognizeLine (const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
 
void OutputStats (const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
 
bool RecognizeLine (const ImageData &image_data, bool invert, bool debug, bool re_invert, bool upside_down, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs)
 
STRING DecodeLabels (const GenericVector< int > &labels)
 
void DisplayForward (const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
 
void LabelsFromOutputs (const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
 

Protected Member Functions

void SetRandomSeed ()
 
void DisplayLSTMOutput (const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window)
 
void DebugActivationPath (const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
 
void DebugActivationRange (const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
 
void LabelsViaReEncode (const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
 
void LabelsViaSimpleText (const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
 
const char * DecodeLabel (const GenericVector< int > &labels, int start, int *end, int *decoded)
 
const char * DecodeSingleLabel (int label)
 

Protected Attributes

Network * network_
 
CCUtil ccutil_
 
UnicharCompress recoder_
 
STRING network_str_
 
int32_t training_flags_
 
int32_t training_iteration_
 
int32_t sample_iteration_
 
int32_t null_char_
 
float learning_rate_
 
float momentum_
 
float adam_beta_
 
TRand randomizer_
 
NetworkScratch scratch_space_
 
Dict * dict_
 
RecodeBeamSearch * search_
 
ScrollView * debug_win_
 

Detailed Description

Definition at line 54 of file lstmrecognizer.h.

Constructor & Destructor Documentation

tesseract::LSTMRecognizer::LSTMRecognizer ( )

Definition at line 49 of file lstmrecognizer.cpp.

tesseract::LSTMRecognizer::~LSTMRecognizer ( )

Definition at line 62 of file lstmrecognizer.cpp.

62  {
63  delete network_;
64  delete dict_;
65  delete search_;
66 }
RecodeBeamSearch * search_

Member Function Documentation

void tesseract::LSTMRecognizer::ConvertToInt ( )
inline

Definition at line 124 of file lstmrecognizer.h.

124  {
 125  if ((training_flags_ & TF_INT_MODE) == 0) {
 126  network_->ConvertToInt();
 127  training_flags_ |= TF_INT_MODE;
 128  }
129  }
virtual void ConvertToInt()
Definition: network.h:190
void tesseract::LSTMRecognizer::DebugActivationPath ( const NetworkIO &  outputs,
const GenericVector< int > &  labels,
const GenericVector< int > &  xcoords 
)
protected

Definition at line 355 of file lstmrecognizer.cpp.

357  {
358  if (xcoords[0] > 0)
359  DebugActivationRange(outputs, "<null>", null_char_, 0, xcoords[0]);
360  int end = 1;
361  for (int start = 0; start < labels.size(); start = end) {
362  if (labels[start] == null_char_) {
363  end = start + 1;
364  DebugActivationRange(outputs, "<null>", null_char_, xcoords[start],
365  xcoords[end]);
366  continue;
367  } else {
368  int decoded;
369  const char* label = DecodeLabel(labels, start, &end, &decoded);
370  DebugActivationRange(outputs, label, labels[start], xcoords[start],
371  xcoords[start + 1]);
372  for (int i = start + 1; i < end; ++i) {
373  DebugActivationRange(outputs, DecodeSingleLabel(labels[i]), labels[i],
374  xcoords[i], xcoords[i + 1]);
375  }
376  }
377  }
378 }
void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
const char * DecodeSingleLabel(int label)
const char * DecodeLabel(const GenericVector< int > &labels, int start, int *end, int *decoded)
int size() const
Definition: genericvector.h:70
void tesseract::LSTMRecognizer::DebugActivationRange ( const NetworkIO &  outputs,
const char *  label,
int  best_choice,
int  x_start,
int  x_end 
)
protected

Definition at line 382 of file lstmrecognizer.cpp.

384  {
385  tprintf("%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end);
386  double max_score = 0.0;
387  double mean_score = 0.0;
388  const int width = x_end - x_start;
389  for (int x = x_start; x < x_end; ++x) {
390  const float* line = outputs.f(x);
391  const double score = line[best_choice] * 100.0;
392  if (score > max_score) max_score = score;
393  mean_score += score / width;
394  int best_c = 0;
395  double best_score = 0.0;
396  for (int c = 0; c < outputs.NumFeatures(); ++c) {
397  if (c != best_choice && line[c] > best_score) {
398  best_c = c;
399  best_score = line[c];
400  }
401  }
402  tprintf(" %.3g(%s=%d=%.3g)", score, DecodeSingleLabel(best_c), best_c,
403  best_score * 100.0);
404  }
405  tprintf(", Mean=%g, max=%g\n", mean_score, max_score);
406 }
const char * DecodeSingleLabel(int label)
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:36
const char * tesseract::LSTMRecognizer::DecodeLabel ( const GenericVector< int > &  labels,
int  start,
int *  end,
int *  decoded 
)
protected

Definition at line 470 of file lstmrecognizer.cpp.

471  {
472  *end = start + 1;
473  if (IsRecoding()) {
474  // Decode labels via recoder_.
475  RecodedCharID code;
476  if (labels[start] == null_char_) {
477  if (decoded != nullptr) {
478  code.Set(0, null_char_);
479  *decoded = recoder_.DecodeUnichar(code);
480  }
481  return "<null>";
482  }
483  int index = start;
484  while (index < labels.size() &&
485  code.length() < RecodedCharID::kMaxCodeLen) {
486  code.Set(code.length(), labels[index++]);
487  while (index < labels.size() && labels[index] == null_char_) ++index;
488  int uni_id = recoder_.DecodeUnichar(code);
489  // If the next label isn't a valid first code, then we need to continue
490  // extending even if we have a valid uni_id from this prefix.
491  if (uni_id != INVALID_UNICHAR_ID &&
492  (index == labels.size() ||
493  code.length() == RecodedCharID::kMaxCodeLen ||
494  recoder_.IsValidFirstCode(labels[index]))) {
495  *end = index;
496  if (decoded != nullptr) *decoded = uni_id;
497  if (uni_id == UNICHAR_SPACE) return " ";
498  return GetUnicharset().get_normed_unichar(uni_id);
499  }
500  }
501  return "<Undecodable>";
502  } else {
503  if (decoded != nullptr) *decoded = labels[start];
504  if (labels[start] == null_char_) return "<null>";
505  if (labels[start] == UNICHAR_SPACE) return " ";
506  return GetUnicharset().get_normed_unichar(labels[start]);
507  }
508 }
const UNICHARSET & GetUnicharset() const
bool IsValidFirstCode(int code) const
static const int kMaxCodeLen
int DecodeUnichar(const RecodedCharID &code) const
const char * get_normed_unichar(UNICHAR_ID unichar_id) const
Definition: unicharset.h:828
int size() const
Definition: genericvector.h:70
STRING tesseract::LSTMRecognizer::DecodeLabels ( const GenericVector< int > &  labels)

Definition at line 297 of file lstmrecognizer.cpp.

297  {
298  STRING result;
299  int end = 1;
300  for (int start = 0; start < labels.size(); start = end) {
301  if (labels[start] == null_char_) {
302  end = start + 1;
303  } else {
304  result += DecodeLabel(labels, start, &end, nullptr);
305  }
306  }
307  return result;
308 }
Definition: strngs.h:45
const char * DecodeLabel(const GenericVector< int > &labels, int start, int *end, int *decoded)
int size() const
Definition: genericvector.h:70
const char * tesseract::LSTMRecognizer::DecodeSingleLabel ( int  label)
protected

Definition at line 512 of file lstmrecognizer.cpp.

512  {
513  if (label == null_char_) return "<null>";
514  if (IsRecoding()) {
515  // Decode label via recoder_.
516  RecodedCharID code;
517  code.Set(0, label);
518  label = recoder_.DecodeUnichar(code);
519  if (label == INVALID_UNICHAR_ID) return ".."; // Part of a bigger code.
520  }
521  if (label == UNICHAR_SPACE) return " ";
522  return GetUnicharset().get_normed_unichar(label);
523 }
const UNICHARSET & GetUnicharset() const
int DecodeUnichar(const RecodedCharID &code) const
const char * get_normed_unichar(UNICHAR_ID unichar_id) const
Definition: unicharset.h:828
bool tesseract::LSTMRecognizer::DeSerialize ( const TessdataManager *  mgr,
TFile *  fp 
)

Definition at line 100 of file lstmrecognizer.cpp.

100  {
 101  delete network_;
 102  network_ = Network::CreateFromFile(fp);
 103  if (network_ == nullptr) return false;
104  bool include_charsets = mgr == nullptr ||
105  !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) ||
106  !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET);
107  if (include_charsets && !ccutil_.unicharset.load_from_file(fp, false))
108  return false;
109  if (!network_str_.DeSerialize(fp)) return false;
110  if (!fp->DeSerialize(&training_flags_)) return false;
111  if (!fp->DeSerialize(&training_iteration_)) return false;
112  if (!fp->DeSerialize(&sample_iteration_)) return false;
113  if (!fp->DeSerialize(&null_char_)) return false;
114  if (!fp->DeSerialize(&adam_beta_)) return false;
115  if (!fp->DeSerialize(&learning_rate_)) return false;
116  if (!fp->DeSerialize(&momentum_)) return false;
117  if (include_charsets && !LoadRecoder(fp)) return false;
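 118  if (!include_charsets && !LoadCharsets(mgr)) return false;
 119  network_->SetRandomizer(&randomizer_);
 120  network_->CacheXScaleFactor(network_->XScaleFactor());
 121  return true;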
122 }
bool LoadCharsets(const TessdataManager *mgr)
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:138
UNICHARSET unicharset
Definition: ccutil.h:71
virtual void CacheXScaleFactor(int factor)
Definition: network.h:214
bool DeSerialize(bool swap, FILE *fp)
Definition: strngs.cpp:159
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:388
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:187
virtual int XScaleFactor() const
Definition: network.h:208
void tesseract::LSTMRecognizer::DisplayForward ( const NetworkIO &  inputs,
const GenericVector< int > &  labels,
const GenericVector< int > &  label_coords,
const char *  window_name,
ScrollView **  window 
)

Definition at line 312 of file lstmrecognizer.cpp.

316  {
317 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
318  Pix* input_pix = inputs.ToPix();
319  Network::ClearWindow(false, window_name, pixGetWidth(input_pix),
320  pixGetHeight(input_pix), window);
321  int line_height = Network::DisplayImage(input_pix, *window);
322  DisplayLSTMOutput(labels, label_coords, line_height, *window);
323 #endif // GRAPHICS_DISABLED
324 }
void DisplayLSTMOutput(const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window)
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:312
static int DisplayImage(Pix *pix, ScrollView *window)
Definition: network.cpp:335
void tesseract::LSTMRecognizer::DisplayLSTMOutput ( const GenericVector< int > &  labels,
const GenericVector< int > &  xcoords,
int  height,
ScrollView *  window 
)
protected

Definition at line 328 of file lstmrecognizer.cpp.

330  {
331 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
332  int x_scale = network_->XScaleFactor();
333  window->TextAttributes("Arial", height / 4, false, false, false);
334  int end = 1;
335  for (int start = 0; start < labels.size(); start = end) {
336  int xpos = xcoords[start] * x_scale;
337  if (labels[start] == null_char_) {
338  end = start + 1;
339  window->Pen(ScrollView::RED);
340  } else {
341  window->Pen(ScrollView::GREEN);
342  const char* str = DecodeLabel(labels, start, &end, nullptr);
343  if (*str == '\\') str = "\\\\";
344  xpos = xcoords[(start + end) / 2] * x_scale;
345  window->Text(xpos, height, str);
346  }
347  window->Line(xpos, 0, xpos, height * 3 / 2);
348  }
349  window->Update();
350 #endif // GRAPHICS_DISABLED
351 }
void Text(int x, int y, const char *mystring)
Definition: scrollview.cpp:652
static void Update()
Definition: scrollview.cpp:709
void TextAttributes(const char *font, int pixel_size, bool bold, bool italic, bool underlined)
Definition: scrollview.cpp:635
const char * DecodeLabel(const GenericVector< int > &labels, int start, int *end, int *decoded)
void Pen(Color color)
Definition: scrollview.cpp:719
virtual int XScaleFactor() const
Definition: network.h:208
void Line(int x1, int y1, int x2, int y2)
Definition: scrollview.cpp:532
int size() const
Definition: genericvector.h:70
GenericVector<STRING> tesseract::LSTMRecognizer::EnumerateLayers ( ) const
inline

Definition at line 79 of file lstmrecognizer.h.

79  {
80  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
81  auto* series = static_cast<Series*>(network_);
82  GenericVector<STRING> layers;
83  series->EnumerateLayers(nullptr, &layers);
84  return layers;
85  }
#define ASSERT_HOST(x)
Definition: errcode.h:88
NetworkType type() const
Definition: network.h:111
const Dict* tesseract::LSTMRecognizer::GetDict ( ) const
inline

Definition at line 136 of file lstmrecognizer.h.

136 { return dict_; }
Network* tesseract::LSTMRecognizer::GetLayer ( const STRING &  id) const
inline

Definition at line 87 of file lstmrecognizer.h.

87  {
88  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
89  ASSERT_HOST(id.length() > 1 && id[0] == ':');
90  auto* series = static_cast<Series*>(network_);
91  return series->GetLayer(&id[1]);
92  }
#define ASSERT_HOST(x)
Definition: errcode.h:88
NetworkType type() const
Definition: network.h:111
float tesseract::LSTMRecognizer::GetLayerLearningRate ( const STRING &  id) const
inline

Definition at line 94 of file lstmrecognizer.h.

94  {
 95  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
 96  if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
 97  ASSERT_HOST(id.length() > 1 && id[0] == ':');
98  auto* series = static_cast<Series*>(network_);
99  return series->LayerLearningRate(&id[1]);
100  } else {
101  return learning_rate_;
102  }
103  }
bool TestFlag(NetworkFlags flag) const
Definition: network.h:143
#define ASSERT_HOST(x)
Definition: errcode.h:88
NetworkType type() const
Definition: network.h:111
const UnicharCompress& tesseract::LSTMRecognizer::GetRecoder ( ) const
inline

Definition at line 134 of file lstmrecognizer.h.

134 { return recoder_; }
const UNICHARSET& tesseract::LSTMRecognizer::GetUnicharset ( ) const
inline

Definition at line 132 of file lstmrecognizer.h.

132 { return ccutil_.unicharset; }
UNICHARSET unicharset
Definition: ccutil.h:71
bool tesseract::LSTMRecognizer::IsIntMode ( ) const
inline

Definition at line 70 of file lstmrecognizer.h.

bool tesseract::LSTMRecognizer::IsRecoding ( ) const
inline

Definition at line 72 of file lstmrecognizer.h.

bool tesseract::LSTMRecognizer::IsTensorFlow ( ) const
inline

Definition at line 76 of file lstmrecognizer.h.

76 { return network_->type() == NT_TENSORFLOW; }
NetworkType type() const
Definition: network.h:111
void tesseract::LSTMRecognizer::LabelsFromOutputs ( const NetworkIO &  outputs,
GenericVector< int > *  labels,
GenericVector< int > *  xcoords 
)

Definition at line 425 of file lstmrecognizer.cpp.

427  {
428  if (SimpleTextOutput()) {
429  LabelsViaSimpleText(outputs, labels, xcoords);
430  } else {
431  LabelsViaReEncode(outputs, labels, xcoords);
432  }
433 }
void LabelsViaReEncode(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
bool SimpleTextOutput() const
void LabelsViaSimpleText(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
void tesseract::LSTMRecognizer::LabelsViaReEncode ( const NetworkIO &  output,
GenericVector< int > *  labels,
GenericVector< int > *  xcoords 
)
protected

Definition at line 437 of file lstmrecognizer.cpp.

439  {
440  if (search_ == nullptr) {
441  search_ =
442  new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
443  }
444  search_->Decode(output, 1.0, 0.0, RecodeBeamSearch::kMinCertainty, nullptr);
445  search_->ExtractBestPathAsLabels(labels, xcoords);
446 }
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:82
static const float kMinCertainty
Definition: recodebeam.h:222
RecodeBeamSearch * search_
bool SimpleTextOutput() const
void ExtractBestPathAsLabels(GenericVector< int > *labels, GenericVector< int > *xcoords) const
Definition: recodebeam.cpp:139
void tesseract::LSTMRecognizer::LabelsViaSimpleText ( const NetworkIO &  output,
GenericVector< int > *  labels,
GenericVector< int > *  xcoords 
)
protected

Definition at line 451 of file lstmrecognizer.cpp.

453  {
454  labels->truncate(0);
455  xcoords->truncate(0);
456  const int width = output.Width();
457  for (int t = 0; t < width; ++t) {
458  float score = 0.0f;
459  const int label = output.BestLabel(t, &score);
460  if (label != null_char_) {
461  labels->push_back(label);
462  xcoords->push_back(t);
463  }
464  }
465  xcoords->push_back(width);
466 }
void truncate(int size)
int push_back(T object)
double tesseract::LSTMRecognizer::learning_rate ( ) const
inline

Definition at line 62 of file lstmrecognizer.h.

62 { return learning_rate_; }
bool tesseract::LSTMRecognizer::Load ( const ParamsVectors *  params,
const char *  lang,
TessdataManager *  mgr 
)

Definition at line 69 of file lstmrecognizer.cpp.

70  {
71  TFile fp;
72  if (!mgr->GetComponent(TESSDATA_LSTM, &fp)) return false;
73  if (!DeSerialize(mgr, &fp)) return false;
74  if (lang == nullptr) return true;
75  // Allow it to run without a dictionary.
76  LoadDictionary(params, lang, mgr);
77  return true;
78 }
bool LoadDictionary(const ParamsVectors *params, const char *lang, TessdataManager *mgr)
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
bool tesseract::LSTMRecognizer::LoadCharsets ( const TessdataManager *  mgr)

Definition at line 125 of file lstmrecognizer.cpp.

125  {
126  TFile fp;
127  if (!mgr->GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) return false;
128  if (!ccutil_.unicharset.load_from_file(&fp, false)) return false;
129  if (!mgr->GetComponent(TESSDATA_LSTM_RECODER, &fp)) return false;
130  if (!LoadRecoder(&fp)) return false;
131  return true;
132 }
UNICHARSET unicharset
Definition: ccutil.h:71
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:388
bool tesseract::LSTMRecognizer::LoadDictionary ( const ParamsVectors *  params,
const char *  lang,
TessdataManager *  mgr 
)

Definition at line 159 of file lstmrecognizer.cpp.

160  {
161  delete dict_;
162  dict_ = new Dict(&ccutil_);
163  dict_->user_words_file.ResetFrom(params);
164  dict_->user_words_suffix.ResetFrom(params);
165  dict_->user_patterns_file.ResetFrom(params);
 166  dict_->user_patterns_suffix.ResetFrom(params);
 167  dict_->SetupForLoad(Dict::GlobalDawgCache());
 168  dict_->LoadLSTM(lang, mgr);
169  if (dict_->FinishLoad()) return true; // Success.
170  tprintf("Failed to load any lstm-specific dictionaries for lang %s!!\n",
171  lang);
172  delete dict_;
173  dict_ = nullptr;
174  return false;
175 }
char * user_patterns_suffix
Definition: dict.h:573
bool FinishLoad()
Definition: dict.cpp:360
void LoadLSTM(const STRING &lang, TessdataManager *data_file)
Definition: dict.cpp:300
void SetupForLoad(DawgCache *dawg_cache)
Definition: dict.cpp:201
char * user_patterns_file
Definition: dict.h:571
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:36
char * user_words_suffix
Definition: dict.h:569
char * user_words_file
Definition: dict.h:567
static TESS_API DawgCache * GlobalDawgCache()
Definition: dict.cpp:193
bool tesseract::LSTMRecognizer::LoadRecoder ( TFile *  fp)

Definition at line 135 of file lstmrecognizer.cpp.

135  {
136  if (IsRecoding()) {
137  if (!recoder_.DeSerialize(fp)) return false;
 138  RecodedCharID code;
 139  recoder_.EncodeUnichar(UNICHAR_SPACE, &code);
 140  if (code(0) != UNICHAR_SPACE) {
141  tprintf("Space was garbled in recoding!!\n");
142  return false;
143  }
 144  } else {
 145  recoder_.SetupPassThrough(GetUnicharset());
 146  training_flags_ |= TF_COMPRESS_UNICHARSET;
 147  }
148  return true;
149 }
const UNICHARSET & GetUnicharset() const
void SetupPassThrough(const UNICHARSET &unicharset)
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:36
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
int tesseract::LSTMRecognizer::null_char ( ) const
inline

Definition at line 143 of file lstmrecognizer.h.

143 { return null_char_; }
int tesseract::LSTMRecognizer::NumInputs ( ) const
inline

Definition at line 142 of file lstmrecognizer.h.

142 { return network_->NumInputs(); }
int NumInputs() const
Definition: network.h:119
int tesseract::LSTMRecognizer::NumOutputs ( ) const
inline

Definition at line 59 of file lstmrecognizer.h.

59 { return network_->NumOutputs(); }
int NumOutputs() const
Definition: network.h:122
LossType tesseract::LSTMRecognizer::OutputLossType ( ) const
inline

Definition at line 63 of file lstmrecognizer.h.

63  {
64  if (network_ == nullptr) return LT_NONE;
65  StaticShape shape;
66  shape = network_->OutputShape(shape);
67  return shape.loss_type();
68  }
LossType loss_type() const
Definition: static_shape.h:50
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:132
void tesseract::LSTMRecognizer::OutputStats ( const NetworkIO &  outputs,
float *  min_output,
float *  mean_output,
float *  sd 
)

Definition at line 201 of file lstmrecognizer.cpp.

202  {
203  const int kOutputScale = INT8_MAX;
204  STATS stats(0, kOutputScale + 1);
205  for (int t = 0; t < outputs.Width(); ++t) {
206  int best_label = outputs.BestLabel(t, nullptr);
207  if (best_label != null_char_) {
208  float best_output = outputs.f(t)[best_label];
209  stats.add(static_cast<int>(kOutputScale * best_output), 1);
210  }
211  }
212  // If the output is all nulls it could be that the photometric interpretation
213  // is wrong, so make it look bad, so the other way can win, even if not great.
214  if (stats.get_total() == 0) {
215  *min_output = 0.0f;
216  *mean_output = 0.0f;
217  *sd = 1.0f;
218  } else {
219  *min_output = static_cast<float>(stats.min_bucket()) / kOutputScale;
220  *mean_output = stats.mean() / kOutputScale;
221  *sd = stats.sd() / kOutputScale;
222  }
223 }
Definition: statistc.h:31
void tesseract::LSTMRecognizer::RecognizeLine ( const ImageData &  image_data,
bool  invert,
bool  debug,
double  worst_dict_cert,
const TBOX &  line_box,
PointerVector< WERD_RES > *  words,
int  lstm_choice_mode = 0 
)

Definition at line 179 of file lstmrecognizer.cpp.

183  {
184  NetworkIO outputs;
185  float scale_factor;
186  NetworkIO inputs;
187  if (!RecognizeLine(image_data, invert, debug, false, false, &scale_factor,
188  &inputs, &outputs))
189  return;
190  if (search_ == nullptr) {
191  search_ =
192  new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
193  }
194  search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert,
195  &GetUnicharset(), lstm_choice_mode);
196  search_->ExtractBestPathAsWords(line_box, scale_factor, debug,
197  &GetUnicharset(), words, lstm_choice_mode);
198 }
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:82
const double kDictRatio
const UNICHARSET & GetUnicharset() const
RecodeBeamSearch * search_
bool SimpleTextOutput() const
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, const UNICHARSET *unicharset, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
Definition: recodebeam.cpp:177
const double kCertOffset
bool tesseract::LSTMRecognizer::RecognizeLine ( const ImageData &  image_data,
bool  invert,
bool  debug,
bool  re_invert,
bool  upside_down,
float *  scale_factor,
NetworkIO *  inputs,
NetworkIO *  outputs 
)

Definition at line 227 of file lstmrecognizer.cpp.

230  {
231  // Maximum width of image to train on.
232  const int kMaxImageWidth = 2560;
233  // This ensures consistent recognition results.
234  SetRandomSeed();
235  int min_width = network_->XScaleFactor();
236  Pix* pix = Input::PrepareLSTMInputs(image_data, network_, min_width,
237  &randomizer_, scale_factor);
238  if (pix == nullptr) {
239  tprintf("Line cannot be recognized!!\n");
240  return false;
241  }
242  if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) {
243  tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix),
244  pixGetHeight(pix));
245  pixDestroy(&pix);
246  return false;
247  }
248  if (upside_down) pixRotate180(pix, pix);
249  // Reduction factor from image to coords.
250  *scale_factor = min_width / *scale_factor;
251  inputs->set_int_mode(IsIntMode());
 252  SetRandomSeed();
 253  Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, inputs);
 254  network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
255  // Check for auto inversion.
256  float pos_min, pos_mean, pos_sd;
257  OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd);
258  if (invert && pos_min < 0.5) {
259  // Run again inverted and see if it is any better.
260  NetworkIO inv_inputs, inv_outputs;
261  inv_inputs.set_int_mode(IsIntMode());
262  SetRandomSeed();
 263  pixInvert(pix, pix);
 264  Input::PreparePixInput(network_->InputShape(), pix, &randomizer_,
 265  &inv_inputs);
266  network_->Forward(debug, inv_inputs, nullptr, &scratch_space_,
267  &inv_outputs);
268  float inv_min, inv_mean, inv_sd;
269  OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd);
270  if (inv_min > pos_min && inv_mean > pos_mean && inv_sd < pos_sd) {
271  // Inverted did better. Use inverted data.
272  if (debug) {
273  tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n",
274  pos_min, pos_mean, pos_sd, inv_min, inv_mean, inv_sd);
275  }
276  *outputs = inv_outputs;
277  *inputs = inv_inputs;
278  } else if (re_invert) {
279  // Inverting was not an improvement, so undo and run again, so the
280  // outputs match the best forward result.
281  SetRandomSeed();
282  network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
283  }
284  }
285  pixDestroy(&pix);
286  if (debug) {
287  GenericVector<int> labels, coords;
288  LabelsFromOutputs(*outputs, &labels, &coords);
289  DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_);
290  DebugActivationPath(*outputs, labels, coords);
291  }
292  return true;
293 }
NetworkScratch scratch_space_
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
static void PreparePixInput(const StaticShape &shape, const Pix *pix, TRand *randomizer, NetworkIO *input)
Definition: input.cpp:111
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
virtual StaticShape InputShape() const
Definition: network.h:126
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)=0
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:36
static Pix * PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
Definition: input.cpp:83
bool IsTraining() const
Definition: network.h:114
virtual int XScaleFactor() const
Definition: network.h:208
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
int tesseract::LSTMRecognizer::sample_iteration ( ) const
inline

Definition at line 61 of file lstmrecognizer.h.

61 { return sample_iteration_; }
void tesseract::LSTMRecognizer::ScaleLayerLearningRate ( const STRING &  id,
double  factor 
)
inline

Definition at line 116 of file lstmrecognizer.h.

116  {
117  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
118  ASSERT_HOST(id.length() > 1 && id[0] == ':');
119  auto* series = static_cast<Series*>(network_);
120  series->ScaleLayerLearningRate(&id[1], factor);
121  }
#define ASSERT_HOST(x)
Definition: errcode.h:88
NetworkType type() const
Definition: network.h:111
void tesseract::LSTMRecognizer::ScaleLearningRate ( double  factor)
inline

Definition at line 105 of file lstmrecognizer.h.

105  {
106  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
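 107  learning_rate_ *= factor;
 108  if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
 109  GenericVector<STRING> layers = EnumerateLayers();
 110  for (int i = 0; i < layers.size(); ++i) {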
111  ScaleLayerLearningRate(layers[i], factor);
112  }
113  }
114  }
GenericVector< STRING > EnumerateLayers() const
bool TestFlag(NetworkFlags flag) const
Definition: network.h:143
void ScaleLayerLearningRate(const STRING &id, double factor)
#define ASSERT_HOST(x)
Definition: errcode.h:88
NetworkType type() const
Definition: network.h:111
int size() const
Definition: genericvector.h:70
bool tesseract::LSTMRecognizer::Serialize ( const TessdataManager *  mgr,
TFile *  fp 
) const

Definition at line 81 of file lstmrecognizer.cpp.

81  {
82  bool include_charsets = mgr == nullptr ||
83  !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) ||
84  !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET);
85  if (!network_->Serialize(fp)) return false;
86  if (include_charsets && !GetUnicharset().save_to_file(fp)) return false;
87  if (!network_str_.Serialize(fp)) return false;
88  if (!fp->Serialize(&training_flags_)) return false;
89  if (!fp->Serialize(&training_iteration_)) return false;
90  if (!fp->Serialize(&sample_iteration_)) return false;
91  if (!fp->Serialize(&null_char_)) return false;
92  if (!fp->Serialize(&adam_beta_)) return false;
93  if (!fp->Serialize(&learning_rate_)) return false;
94  if (!fp->Serialize(&momentum_)) return false;
95  if (include_charsets && IsRecoding() && !recoder_.Serialize(fp)) return false;
96  return true;
97 }
bool save_to_file(const char *const filename) const
Definition: unicharset.h:350
bool Serialize(TFile *fp) const
const UNICHARSET & GetUnicharset() const
bool Serialize(FILE *fp) const
Definition: strngs.cpp:146
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:151
void tesseract::LSTMRecognizer::SetIteration ( int  iteration)
inline

Definition at line 140 of file lstmrecognizer.h.

140 { sample_iteration_ = iteration; }
void tesseract::LSTMRecognizer::SetRandomSeed ( )
inline protected

Definition at line 214 of file lstmrecognizer.h.

214  {
215  int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001;
 216  randomizer_.set_seed(seed);
 217  randomizer_.IntRand();
 218  }
int32_t IntRand()
Definition: helpers.h:50
void set_seed(uint64_t seed)
Definition: helpers.h:40
bool tesseract::LSTMRecognizer::SimpleTextOutput ( ) const
inline

Definition at line 69 of file lstmrecognizer.h.

int tesseract::LSTMRecognizer::training_iteration ( ) const
inline

Definition at line 60 of file lstmrecognizer.h.

60 { return training_iteration_; }

Member Data Documentation

float tesseract::LSTMRecognizer::adam_beta_
protected

Definition at line 283 of file lstmrecognizer.h.

CCUtil tesseract::LSTMRecognizer::ccutil_
protected

Definition at line 261 of file lstmrecognizer.h.

ScrollView* tesseract::LSTMRecognizer::debug_win_
protected

Definition at line 295 of file lstmrecognizer.h.

Dict* tesseract::LSTMRecognizer::dict_
protected

Definition at line 289 of file lstmrecognizer.h.

float tesseract::LSTMRecognizer::learning_rate_
protected

Definition at line 280 of file lstmrecognizer.h.

float tesseract::LSTMRecognizer::momentum_
protected

Definition at line 281 of file lstmrecognizer.h.

Network* tesseract::LSTMRecognizer::network_
protected

Definition at line 258 of file lstmrecognizer.h.

STRING tesseract::LSTMRecognizer::network_str_
protected

Definition at line 268 of file lstmrecognizer.h.

int32_t tesseract::LSTMRecognizer::null_char_
protected

Definition at line 278 of file lstmrecognizer.h.

TRand tesseract::LSTMRecognizer::randomizer_
protected

Definition at line 286 of file lstmrecognizer.h.

UnicharCompress tesseract::LSTMRecognizer::recoder_
protected

Definition at line 265 of file lstmrecognizer.h.

int32_t tesseract::LSTMRecognizer::sample_iteration_
protected

Definition at line 275 of file lstmrecognizer.h.

NetworkScratch tesseract::LSTMRecognizer::scratch_space_
protected

Definition at line 287 of file lstmrecognizer.h.

RecodeBeamSearch* tesseract::LSTMRecognizer::search_
protected

Definition at line 291 of file lstmrecognizer.h.

int32_t tesseract::LSTMRecognizer::training_flags_
protected

Definition at line 271 of file lstmrecognizer.h.

int32_t tesseract::LSTMRecognizer::training_iteration_
protected

Definition at line 273 of file lstmrecognizer.h.


The documentation for this class was generated from the following files: