tesseract 3.04.01

cube/conv_net_classifier.cpp

Go to the documentation of this file.
00001 /**********************************************************************
00002  * File:        conv_net_classifier.cpp
00003  * Description: Implementation of Convolutional-NeuralNet Character Classifier
00004  * Author:    Ahmad Abdulkader
00005  * Created:   2007
00006  *
00007  * (C) Copyright 2008, Google Inc.
00008  ** Licensed under the Apache License, Version 2.0 (the "License");
00009  ** you may not use this file except in compliance with the License.
00010  ** You may obtain a copy of the License at
00011  ** http://www.apache.org/licenses/LICENSE-2.0
00012  ** Unless required by applicable law or agreed to in writing, software
00013  ** distributed under the License is distributed on an "AS IS" BASIS,
00014  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00015  ** See the License for the specific language governing permissions and
00016  ** limitations under the License.
00017  *
00018  **********************************************************************/
00019 
00020 #include <algorithm>
00021 #include <stdio.h>
00022 #include <stdlib.h>
00023 #include <string>
00024 #include <vector>
00025 #include <wctype.h>
00026 
00027 #include "char_set.h"
00028 #include "classifier_base.h"
00029 #include "const.h"
00030 #include "conv_net_classifier.h"
00031 #include "cube_utils.h"
00032 #include "feature_base.h"
00033 #include "feature_bmp.h"
00034 #include "tess_lang_model.h"
00035 
00036 namespace tesseract {
00037 
00038 ConvNetCharClassifier::ConvNetCharClassifier(CharSet *char_set,
00039                                              TuningParams *params,
00040                                              FeatureBase *feat_extract)
00041     : CharClassifier(char_set, params, feat_extract) {
00042   char_net_ = NULL;
00043   net_input_ = NULL;
00044   net_output_ = NULL;
00045 }
00046 
00047 ConvNetCharClassifier::~ConvNetCharClassifier() {
00048   if (char_net_ != NULL) {
00049     delete char_net_;
00050     char_net_ = NULL;
00051   }
00052 
00053   if (net_input_ != NULL) {
00054     delete []net_input_;
00055     net_input_ = NULL;
00056   }
00057 
00058   if (net_output_ != NULL) {
00059     delete []net_output_;
00060     net_output_ = NULL;
00061   }
00062 }
00063 
00069 bool ConvNetCharClassifier::Train(CharSamp *char_samp, int ClassID) {
00070   return false;
00071 }
00072 
00078 bool ConvNetCharClassifier::SetLearnParam(char *var_name, float val) {
00079   // TODO(ahmadab): implementation of parameter initializing.
00080   return false;
00081 }
00082 
00086 void ConvNetCharClassifier::Fold() {
00087   // in case insensitive mode
00088   if (case_sensitive_ == false) {
00089     int class_cnt = char_set_->ClassCount();
00090     // fold case
00091     for (int class_id = 0; class_id < class_cnt; class_id++) {
00092       // get class string
00093       const char_32 *str32 = char_set_->ClassString(class_id);
00094       // get the upper case form of the string
00095       string_32 upper_form32 = str32;
00096       for (int ch = 0; ch < upper_form32.length(); ch++) {
00097         if (iswalpha(static_cast<int>(upper_form32[ch])) != 0) {
00098           upper_form32[ch] = towupper(upper_form32[ch]);
00099         }
00100       }
00101 
00102       // find out the upperform class-id if any
00103       int upper_class_id =
00104           char_set_->ClassID(reinterpret_cast<const char_32 *>(
00105           upper_form32.c_str()));
00106       if (upper_class_id != -1 && class_id != upper_class_id) {
00107         float max_out = MAX(net_output_[class_id], net_output_[upper_class_id]);
00108         net_output_[class_id] = max_out;
00109         net_output_[upper_class_id] = max_out;
00110       }
00111     }
00112   }
00113 
00114   // The folding sets specify how groups of classes should be folded
00115   // Folding involved assigning a min-activation to all the members
00116   // of the folding set. The min-activation is a fraction of the max-activation
00117   // of the members of the folding set
00118   for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) {
00119     if (fold_set_len_[fold_set] == 0)
00120       continue;
00121     float max_prob = net_output_[fold_sets_[fold_set][0]];
00122     for (int ch = 1; ch < fold_set_len_[fold_set]; ch++) {
00123       if (net_output_[fold_sets_[fold_set][ch]] > max_prob) {
00124         max_prob = net_output_[fold_sets_[fold_set][ch]];
00125       }
00126     }
00127     for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
00128       net_output_[fold_sets_[fold_set][ch]] = MAX(max_prob * kFoldingRatio,
00129           net_output_[fold_sets_[fold_set][ch]]);
00130     }
00131   }
00132 }
00133 
00138 bool ConvNetCharClassifier::RunNets(CharSamp *char_samp) {
00139   if (char_net_ == NULL) {
00140     fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): "
00141             "NeuralNet is NULL\n");
00142     return false;
00143   }
00144   int feat_cnt = char_net_->in_cnt();
00145   int class_cnt = char_set_->ClassCount();
00146 
00147   // allocate i/p and o/p buffers if needed
00148   if (net_input_ == NULL) {
00149     net_input_ = new float[feat_cnt];
00150     if (net_input_ == NULL) {
00151       fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): "
00152             "unable to allocate memory for input nodes\n");
00153       return false;
00154     }
00155 
00156     net_output_ = new float[class_cnt];
00157     if (net_output_ == NULL) {
00158       fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): "
00159             "unable to allocate memory for output nodes\n");
00160       return false;
00161     }
00162   }
00163 
00164   // compute input features
00165   if (feat_extract_->ComputeFeatures(char_samp, net_input_) == false) {
00166     fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): "
00167             "unable to compute features\n");
00168     return false;
00169   }
00170 
00171   if (char_net_ != NULL) {
00172     if (char_net_->FeedForward(net_input_, net_output_) == false) {
00173       fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): "
00174               "unable to run feed-forward\n");
00175       return false;
00176     }
00177   } else {
00178     return false;
00179   }
00180   Fold();
00181   return true;
00182 }
00183 
00187 int ConvNetCharClassifier::CharCost(CharSamp *char_samp) {
00188   if (RunNets(char_samp) == false) {
00189     return 0;
00190   }
00191   return CubeUtils::Prob2Cost(1.0f - net_output_[0]);
00192 }
00193 
00198 CharAltList *ConvNetCharClassifier::Classify(CharSamp *char_samp) {
00199   // run the needed nets
00200   if (RunNets(char_samp) == false) {
00201     return NULL;
00202   }
00203 
00204   int class_cnt = char_set_->ClassCount();
00205 
00206   // create an altlist
00207   CharAltList *alt_list = new CharAltList(char_set_, class_cnt);
00208   if (alt_list == NULL) {
00209     fprintf(stderr, "Cube WARNING (ConvNetCharClassifier::Classify): "
00210             "returning emtpy CharAltList\n");
00211     return NULL;
00212   }
00213 
00214   for (int out = 1; out < class_cnt; out++) {
00215     int cost = CubeUtils::Prob2Cost(net_output_[out]);
00216     alt_list->Insert(out, cost);
00217   }
00218 
00219   return alt_list;
00220 }
00221 
00225 void ConvNetCharClassifier::SetNet(tesseract::NeuralNet *char_net) {
00226   if (char_net_ != NULL) {
00227     delete char_net_;
00228     char_net_ = NULL;
00229   }
00230   char_net_ = char_net;
00231 }
00232 
00237 bool ConvNetCharClassifier::LoadFoldingSets(const string &data_file_path,
00238                                             const string &lang,
00239                                             LangModel *lang_mod) {
00240   fold_set_cnt_ = 0;
00241   string fold_file_name;
00242   fold_file_name = data_file_path + lang;
00243   fold_file_name += ".cube.fold";
00244 
00245   // folding sets are optional
00246   FILE *fp = fopen(fold_file_name.c_str(), "rb");
00247   if (fp == NULL) {
00248     return true;
00249   }
00250   fclose(fp);
00251 
00252   string fold_sets_str;
00253   if (!CubeUtils::ReadFileToString(fold_file_name,
00254                                    &fold_sets_str)) {
00255     return false;
00256   }
00257 
00258   // split into lines
00259   vector<string> str_vec;
00260   CubeUtils::SplitStringUsing(fold_sets_str, "\r\n", &str_vec);
00261   fold_set_cnt_ = str_vec.size();
00262 
00263   fold_sets_ = new int *[fold_set_cnt_];
00264   if (fold_sets_ == NULL) {
00265     return false;
00266   }
00267   fold_set_len_ = new int[fold_set_cnt_];
00268   if (fold_set_len_ == NULL) {
00269     fold_set_cnt_ = 0;
00270     return false;
00271   }
00272 
00273   for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) {
00274     reinterpret_cast<TessLangModel *>(lang_mod)->RemoveInvalidCharacters(
00275         &str_vec[fold_set]);
00276 
00277     // if all or all but one character are invalid, invalidate this set
00278     if (str_vec[fold_set].length() <= 1) {
00279       fprintf(stderr, "Cube WARNING (ConvNetCharClassifier::LoadFoldingSets): "
00280               "invalidating folding set %d\n", fold_set);
00281       fold_set_len_[fold_set] = 0;
00282       fold_sets_[fold_set] = NULL;
00283       continue;
00284     }
00285 
00286     string_32 str32;
00287     CubeUtils::UTF8ToUTF32(str_vec[fold_set].c_str(), &str32);
00288     fold_set_len_[fold_set] = str32.length();
00289     fold_sets_[fold_set] = new int[fold_set_len_[fold_set]];
00290     if (fold_sets_[fold_set] == NULL) {
00291       fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadFoldingSets): "
00292               "could not allocate folding set\n");
00293       fold_set_cnt_ = fold_set;
00294       return false;
00295     }
00296     for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
00297       fold_sets_[fold_set][ch] = char_set_->ClassID(str32[ch]);
00298     }
00299   }
00300   return true;
00301 }
00302 
00306 bool ConvNetCharClassifier::Init(const string &data_file_path,
00307                                  const string &lang,
00308                                  LangModel *lang_mod) {
00309   if (init_) {
00310     return true;
00311   }
00312 
00313   // load the nets if any. This function will return true if the net file
00314   // does not exist. But will fail if the net did not pass the sanity checks
00315   if (!LoadNets(data_file_path, lang)) {
00316     return false;
00317   }
00318 
00319   // load the folding sets if any. This function will return true if the
00320   // file does not exist. But will fail if the it did not pass the sanity checks
00321   if (!LoadFoldingSets(data_file_path, lang, lang_mod)) {
00322     return false;
00323   }
00324 
00325   init_ = true;
00326   return true;
00327 }
00328 
00334 bool ConvNetCharClassifier::LoadNets(const string &data_file_path,
00335                                      const string &lang) {
00336   string char_net_file;
00337 
00338   // add the lang identifier
00339   char_net_file = data_file_path + lang;
00340   char_net_file += ".cube.nn";
00341 
00342   // neural network is optional
00343   FILE *fp = fopen(char_net_file.c_str(), "rb");
00344   if (fp == NULL) {
00345     return true;
00346   }
00347   fclose(fp);
00348 
00349   // load main net
00350   char_net_ = tesseract::NeuralNet::FromFile(char_net_file);
00351   if (char_net_ == NULL) {
00352     fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadNets): "
00353             "could not load %s\n", char_net_file.c_str());
00354     return false;
00355   }
00356 
00357   // validate net
00358   if (char_net_->in_cnt()!= feat_extract_->FeatureCnt()) {
00359     fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadNets): "
00360             "could not validate net %s\n", char_net_file.c_str());
00361     return false;
00362   }
00363 
00364   // alloc net i/o buffers
00365   int feat_cnt = char_net_->in_cnt();
00366   int class_cnt = char_set_->ClassCount();
00367 
00368   if (char_net_->out_cnt() != class_cnt) {
00369     fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadNets): "
00370             "output count (%d) and class count (%d) are not equal\n",
00371             char_net_->out_cnt(), class_cnt);
00372     return false;
00373   }
00374 
00375   // allocate i/p and o/p buffers if needed
00376   if (net_input_ == NULL) {
00377     net_input_ = new float[feat_cnt];
00378     if (net_input_ == NULL) {
00379       return false;
00380     }
00381 
00382     net_output_ = new float[class_cnt];
00383     if (net_output_ == NULL) {
00384       return false;
00385     }
00386   }
00387 
00388   return true;
00389 }
00390 }  // tesseract
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines