|
tesseract 3.04.01
|
00001 /********************************************************************** 00002 * File: charclassifier.cpp 00003 * Description: Implementation of Convolutional-NeuralNet Character Classifier 00004 * Author: Ahmad Abdulkader 00005 * Created: 2007 00006 * 00007 * (C) Copyright 2008, Google Inc. 00008 ** Licensed under the Apache License, Version 2.0 (the "License"); 00009 ** you may not use this file except in compliance with the License. 00010 ** You may obtain a copy of the License at 00011 ** http://www.apache.org/licenses/LICENSE-2.0 00012 ** Unless required by applicable law or agreed to in writing, software 00013 ** distributed under the License is distributed on an "AS IS" BASIS, 00014 ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00015 ** See the License for the specific language governing permissions and 00016 ** limitations under the License. 00017 * 00018 **********************************************************************/ 00019 00020 #include <algorithm> 00021 #include <stdio.h> 00022 #include <stdlib.h> 00023 #include <string> 00024 #include <vector> 00025 #include <wctype.h> 00026 00027 #include "char_set.h" 00028 #include "classifier_base.h" 00029 #include "const.h" 00030 #include "conv_net_classifier.h" 00031 #include "cube_utils.h" 00032 #include "feature_base.h" 00033 #include "feature_bmp.h" 00034 #include "tess_lang_model.h" 00035 00036 namespace tesseract { 00037 00038 ConvNetCharClassifier::ConvNetCharClassifier(CharSet *char_set, 00039 TuningParams *params, 00040 FeatureBase *feat_extract) 00041 : CharClassifier(char_set, params, feat_extract) { 00042 char_net_ = NULL; 00043 net_input_ = NULL; 00044 net_output_ = NULL; 00045 } 00046 00047 ConvNetCharClassifier::~ConvNetCharClassifier() { 00048 if (char_net_ != NULL) { 00049 delete char_net_; 00050 char_net_ = NULL; 00051 } 00052 00053 if (net_input_ != NULL) { 00054 delete []net_input_; 00055 net_input_ = NULL; 00056 } 00057 00058 if (net_output_ != NULL) { 00059 delete []net_output_; 00060 net_output_ = NULL; 00061 } 00062 } 00063 00069 bool ConvNetCharClassifier::Train(CharSamp *char_samp, int ClassID) { 00070 return false; 00071 } 00072 00078 bool ConvNetCharClassifier::SetLearnParam(char *var_name, float val) { 00079 // TODO(ahmadab): implementation of parameter initializing. 00080 return false; 00081 } 00082 00086 void ConvNetCharClassifier::Fold() { 00087 // in case insensitive mode 00088 if (case_sensitive_ == false) { 00089 int class_cnt = char_set_->ClassCount(); 00090 // fold case 00091 for (int class_id = 0; class_id < class_cnt; class_id++) { 00092 // get class string 00093 const char_32 *str32 = char_set_->ClassString(class_id); 00094 // get the upper case form of the string 00095 string_32 upper_form32 = str32; 00096 for (int ch = 0; ch < upper_form32.length(); ch++) { 00097 if (iswalpha(static_cast<int>(upper_form32[ch])) != 0) { 00098 upper_form32[ch] = towupper(upper_form32[ch]); 00099 } 00100 } 00101 00102 // find out the upperform class-id if any 00103 int upper_class_id = 00104 char_set_->ClassID(reinterpret_cast<const char_32 *>( 00105 upper_form32.c_str())); 00106 if (upper_class_id != -1 && class_id != upper_class_id) { 00107 float max_out = MAX(net_output_[class_id], net_output_[upper_class_id]); 00108 net_output_[class_id] = max_out; 00109 net_output_[upper_class_id] = max_out; 00110 } 00111 } 00112 } 00113 00114 // The folding sets specify how groups of classes should be folded 00115 // Folding involved assigning a min-activation to all the members 00116 // of the folding set. The min-activation is a fraction of the max-activation 00117 // of the members of the folding set 00118 for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) { 00119 if (fold_set_len_[fold_set] == 0) 00120 continue; 00121 float max_prob = net_output_[fold_sets_[fold_set][0]]; 00122 for (int ch = 1; ch < fold_set_len_[fold_set]; ch++) { 00123 if (net_output_[fold_sets_[fold_set][ch]] > max_prob) { 00124 max_prob = net_output_[fold_sets_[fold_set][ch]]; 00125 } 00126 } 00127 for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) { 00128 net_output_[fold_sets_[fold_set][ch]] = MAX(max_prob * kFoldingRatio, 00129 net_output_[fold_sets_[fold_set][ch]]); 00130 } 00131 } 00132 } 00133 00138 bool ConvNetCharClassifier::RunNets(CharSamp *char_samp) { 00139 if (char_net_ == NULL) { 00140 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): " 00141 "NeuralNet is NULL\n"); 00142 return false; 00143 } 00144 int feat_cnt = char_net_->in_cnt(); 00145 int class_cnt = char_set_->ClassCount(); 00146 00147 // allocate i/p and o/p buffers if needed 00148 if (net_input_ == NULL) { 00149 net_input_ = new float[feat_cnt]; 00150 if (net_input_ == NULL) { 00151 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): " 00152 "unable to allocate memory for input nodes\n"); 00153 return false; 00154 } 00155 00156 net_output_ = new float[class_cnt]; 00157 if (net_output_ == NULL) { 00158 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): " 00159 "unable to allocate memory for output nodes\n"); 00160 return false; 00161 } 00162 } 00163 00164 // compute input features 00165 if (feat_extract_->ComputeFeatures(char_samp, net_input_) == false) { 00166 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): " 00167 "unable to compute features\n"); 00168 return false; 00169 } 00170 00171 if (char_net_ != NULL) { 00172 if (char_net_->FeedForward(net_input_, net_output_) == false) { 00173 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): " 00174 "unable to run feed-forward\n"); 00175 return false; 00176 } 00177 } else { 00178 return false; 00179 } 00180 Fold(); 00181 return true; 00182 } 00183 00187 int ConvNetCharClassifier::CharCost(CharSamp *char_samp) { 00188 if (RunNets(char_samp) == false) { 00189 return 0; 00190 } 00191 return CubeUtils::Prob2Cost(1.0f - net_output_[0]); 00192 } 00193 00198 CharAltList *ConvNetCharClassifier::Classify(CharSamp *char_samp) { 00199 // run the needed nets 00200 if (RunNets(char_samp) == false) { 00201 return NULL; 00202 } 00203 00204 int class_cnt = char_set_->ClassCount(); 00205 00206 // create an altlist 00207 CharAltList *alt_list = new CharAltList(char_set_, class_cnt); 00208 if (alt_list == NULL) { 00209 fprintf(stderr, "Cube WARNING (ConvNetCharClassifier::Classify): " 00210 "returning emtpy CharAltList\n"); 00211 return NULL; 00212 } 00213 00214 for (int out = 1; out < class_cnt; out++) { 00215 int cost = CubeUtils::Prob2Cost(net_output_[out]); 00216 alt_list->Insert(out, cost); 00217 } 00218 00219 return alt_list; 00220 } 00221 00225 void ConvNetCharClassifier::SetNet(tesseract::NeuralNet *char_net) { 00226 if (char_net_ != NULL) { 00227 delete char_net_; 00228 char_net_ = NULL; 00229 } 00230 char_net_ = char_net; 00231 } 00232 00237 bool ConvNetCharClassifier::LoadFoldingSets(const string &data_file_path, 00238 const string &lang, 00239 LangModel *lang_mod) { 00240 fold_set_cnt_ = 0; 00241 string fold_file_name; 00242 fold_file_name = data_file_path + lang; 00243 fold_file_name += ".cube.fold"; 00244 00245 // folding sets are optional 00246 FILE *fp = fopen(fold_file_name.c_str(), "rb"); 00247 if (fp == NULL) { 00248 return true; 00249 } 00250 fclose(fp); 00251 00252 string fold_sets_str; 00253 if (!CubeUtils::ReadFileToString(fold_file_name, 00254 &fold_sets_str)) { 00255 return false; 00256 } 00257 00258 // split into lines 00259 vector<string> str_vec; 00260 CubeUtils::SplitStringUsing(fold_sets_str, "\r\n", &str_vec); 00261 fold_set_cnt_ = str_vec.size(); 00262 00263 fold_sets_ = new int *[fold_set_cnt_]; 00264 if (fold_sets_ == NULL) { 00265 return false; 00266 } 00267 fold_set_len_ = new int[fold_set_cnt_]; 00268 if (fold_set_len_ == NULL) { 00269 fold_set_cnt_ = 0; 00270 return false; 00271 } 00272 00273 for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) { 00274 reinterpret_cast<TessLangModel *>(lang_mod)->RemoveInvalidCharacters( 00275 &str_vec[fold_set]); 00276 00277 // if all or all but one character are invalid, invalidate this set 00278 if (str_vec[fold_set].length() <= 1) { 00279 fprintf(stderr, "Cube WARNING (ConvNetCharClassifier::LoadFoldingSets): " 00280 "invalidating folding set %d\n", fold_set); 00281 fold_set_len_[fold_set] = 0; 00282 fold_sets_[fold_set] = NULL; 00283 continue; 00284 } 00285 00286 string_32 str32; 00287 CubeUtils::UTF8ToUTF32(str_vec[fold_set].c_str(), &str32); 00288 fold_set_len_[fold_set] = str32.length(); 00289 fold_sets_[fold_set] = new int[fold_set_len_[fold_set]]; 00290 if (fold_sets_[fold_set] == NULL) { 00291 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadFoldingSets): " 00292 "could not allocate folding set\n"); 00293 fold_set_cnt_ = fold_set; 00294 return false; 00295 } 00296 for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) { 00297 fold_sets_[fold_set][ch] = char_set_->ClassID(str32[ch]); 00298 } 00299 } 00300 return true; 00301 } 00302 00306 bool ConvNetCharClassifier::Init(const string &data_file_path, 00307 const string &lang, 00308 LangModel *lang_mod) { 00309 if (init_) { 00310 return true; 00311 } 00312 00313 // load the nets if any. This function will return true if the net file 00314 // does not exist. But will fail if the net did not pass the sanity checks 00315 if (!LoadNets(data_file_path, lang)) { 00316 return false; 00317 } 00318 00319 // load the folding sets if any. This function will return true if the 00320 // file does not exist. But will fail if the it did not pass the sanity checks 00321 if (!LoadFoldingSets(data_file_path, lang, lang_mod)) { 00322 return false; 00323 } 00324 00325 init_ = true; 00326 return true; 00327 } 00328 00334 bool ConvNetCharClassifier::LoadNets(const string &data_file_path, 00335 const string &lang) { 00336 string char_net_file; 00337 00338 // add the lang identifier 00339 char_net_file = data_file_path + lang; 00340 char_net_file += ".cube.nn"; 00341 00342 // neural network is optional 00343 FILE *fp = fopen(char_net_file.c_str(), "rb"); 00344 if (fp == NULL) { 00345 return true; 00346 } 00347 fclose(fp); 00348 00349 // load main net 00350 char_net_ = tesseract::NeuralNet::FromFile(char_net_file); 00351 if (char_net_ == NULL) { 00352 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadNets): " 00353 "could not load %s\n", char_net_file.c_str()); 00354 return false; 00355 } 00356 00357 // validate net 00358 if (char_net_->in_cnt()!= feat_extract_->FeatureCnt()) { 00359 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadNets): " 00360 "could not validate net %s\n", char_net_file.c_str()); 00361 return false; 00362 } 00363 00364 // alloc net i/o buffers 00365 int feat_cnt = char_net_->in_cnt(); 00366 int class_cnt = char_set_->ClassCount(); 00367 00368 if (char_net_->out_cnt() != class_cnt) { 00369 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadNets): " 00370 "output count (%d) and class count (%d) are not equal\n", 00371 char_net_->out_cnt(), class_cnt); 00372 return false; 00373 } 00374 00375 // allocate i/p and o/p buffers if needed 00376 if (net_input_ == NULL) { 00377 net_input_ = new float[feat_cnt]; 00378 if (net_input_ == NULL) { 00379 return false; 00380 } 00381 00382 net_output_ = new float[class_cnt]; 00383 if (net_output_ == NULL) { 00384 return false; 00385 } 00386 } 00387 00388 return true; 00389 } 00390 } // tesseract