|
tesseract 3.04.01
|
00001 // Copyright 2010 Google Inc. All Rights Reserved. 00002 // Author: rays@google.com (Ray Smith) 00004 // File: mastertrainer.cpp 00005 // Description: Trainer to build the MasterClassifier. 00006 // Author: Ray Smith 00007 // Created: Wed Nov 03 18:10:01 PDT 2010 00008 // 00009 // (C) Copyright 2010, Google Inc. 00010 // Licensed under the Apache License, Version 2.0 (the "License"); 00011 // you may not use this file except in compliance with the License. 00012 // You may obtain a copy of the License at 00013 // http://www.apache.org/licenses/LICENSE-2.0 00014 // Unless required by applicable law or agreed to in writing, software 00015 // distributed under the License is distributed on an "AS IS" BASIS, 00016 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00017 // See the License for the specific language governing permissions and 00018 // limitations under the License. 00019 // 00021 00022 // Include automatically generated configuration file if running autoconf. 00023 #ifdef HAVE_CONFIG_H 00024 #include "config_auto.h" 00025 #endif 00026 00027 #include "mastertrainer.h" 00028 #include <math.h> 00029 #include <time.h> 00030 #include "allheaders.h" 00031 #include "boxread.h" 00032 #include "classify.h" 00033 #include "efio.h" 00034 #include "errorcounter.h" 00035 #include "featdefs.h" 00036 #include "sampleiterator.h" 00037 #include "shapeclassifier.h" 00038 #include "shapetable.h" 00039 #include "svmnode.h" 00040 00041 #include "scanutils.h" 00042 00043 namespace tesseract { 00044 00045 // Constants controlling clustering. With a low kMinClusteredShapes and a high 00046 // kMaxUnicharsPerCluster, then kFontMergeDistance is the only limiting factor. 00047 // Min number of shapes in the output. 00048 const int kMinClusteredShapes = 1; 00049 // Max number of unichars in any individual cluster. 00050 const int kMaxUnicharsPerCluster = 2000; 00051 // Mean font distance below which to merge fonts and unichars. 00052 const float kFontMergeDistance = 0.025; 00053 00054 MasterTrainer::MasterTrainer(NormalizationMode norm_mode, 00055 bool shape_analysis, 00056 bool replicate_samples, 00057 int debug_level) 00058 : norm_mode_(norm_mode), samples_(fontinfo_table_), 00059 junk_samples_(fontinfo_table_), verify_samples_(fontinfo_table_), 00060 charsetsize_(0), 00061 enable_shape_anaylsis_(shape_analysis), 00062 enable_replication_(replicate_samples), 00063 fragments_(NULL), prev_unichar_id_(-1), debug_level_(debug_level) { 00064 } 00065 00066 MasterTrainer::~MasterTrainer() { 00067 delete [] fragments_; 00068 for (int p = 0; p < page_images_.size(); ++p) 00069 pixDestroy(&page_images_[p]); 00070 } 00071 00072 // WARNING! Serialize/DeSerialize are only partial, providing 00073 // enough data to get the samples back and display them. 00074 // Writes to the given file. Returns false in case of error. 00075 bool MasterTrainer::Serialize(FILE* fp) const { 00076 if (fwrite(&norm_mode_, sizeof(norm_mode_), 1, fp) != 1) return false; 00077 if (!unicharset_.save_to_file(fp)) return false; 00078 if (!feature_space_.Serialize(fp)) return false; 00079 if (!samples_.Serialize(fp)) return false; 00080 if (!junk_samples_.Serialize(fp)) return false; 00081 if (!verify_samples_.Serialize(fp)) return false; 00082 if (!master_shapes_.Serialize(fp)) return false; 00083 if (!flat_shapes_.Serialize(fp)) return false; 00084 if (!fontinfo_table_.Serialize(fp)) return false; 00085 if (!xheights_.Serialize(fp)) return false; 00086 return true; 00087 } 00088 00089 // Reads from the given file. Returns false in case of error. 00090 // If swap is true, assumes a big/little-endian swap is needed. 00091 bool MasterTrainer::DeSerialize(bool swap, FILE* fp) { 00092 if (fread(&norm_mode_, sizeof(norm_mode_), 1, fp) != 1) return false; 00093 if (swap) { 00094 ReverseN(&norm_mode_, sizeof(norm_mode_)); 00095 } 00096 if (!unicharset_.load_from_file(fp)) return false; 00097 charsetsize_ = unicharset_.size(); 00098 if (!feature_space_.DeSerialize(swap, fp)) return false; 00099 feature_map_.Init(feature_space_); 00100 if (!samples_.DeSerialize(swap, fp)) return false; 00101 if (!junk_samples_.DeSerialize(swap, fp)) return false; 00102 if (!verify_samples_.DeSerialize(swap, fp)) return false; 00103 if (!master_shapes_.DeSerialize(swap, fp)) return false; 00104 if (!flat_shapes_.DeSerialize(swap, fp)) return false; 00105 if (!fontinfo_table_.DeSerialize(swap, fp)) return false; 00106 if (!xheights_.DeSerialize(swap, fp)) return false; 00107 return true; 00108 } 00109 00110 // Load an initial unicharset, or set one up if the file cannot be read. 00111 void MasterTrainer::LoadUnicharset(const char* filename) { 00112 if (!unicharset_.load_from_file(filename)) { 00113 tprintf("Failed to load unicharset from file %s\n" 00114 "Building unicharset for training from scratch...\n", 00115 filename); 00116 unicharset_.clear(); 00117 UNICHARSET initialized; 00118 // Add special characters, as they were removed by the clear, but the 00119 // default constructor puts them in. 00120 unicharset_.AppendOtherUnicharset(initialized); 00121 } 00122 charsetsize_ = unicharset_.size(); 00123 delete [] fragments_; 00124 fragments_ = new int[charsetsize_]; 00125 memset(fragments_, 0, sizeof(*fragments_) * charsetsize_); 00126 samples_.LoadUnicharset(filename); 00127 junk_samples_.LoadUnicharset(filename); 00128 verify_samples_.LoadUnicharset(filename); 00129 } 00130 00131 // Reads the samples and their features from the given .tr format file, 00132 // adding them to the trainer with the font_id from the content of the file. 00133 // See mftraining.cpp for a description of the file format. 00134 // If verification, then these are verification samples, not training. 00135 void MasterTrainer::ReadTrainingSamples(const char* page_name, 00136 const FEATURE_DEFS_STRUCT& feature_defs, 00137 bool verification) { 00138 char buffer[2048]; 00139 int int_feature_type = ShortNameToFeatureType(feature_defs, kIntFeatureType); 00140 int micro_feature_type = ShortNameToFeatureType(feature_defs, 00141 kMicroFeatureType); 00142 int cn_feature_type = ShortNameToFeatureType(feature_defs, kCNFeatureType); 00143 int geo_feature_type = ShortNameToFeatureType(feature_defs, kGeoFeatureType); 00144 00145 FILE* fp = Efopen(page_name, "rb"); 00146 if (fp == NULL) { 00147 tprintf("Failed to open tr file: %s\n", page_name); 00148 return; 00149 } 00150 tr_filenames_.push_back(STRING(page_name)); 00151 while (fgets(buffer, sizeof(buffer), fp) != NULL) { 00152 if (buffer[0] == '\n') 00153 continue; 00154 00155 char* space = strchr(buffer, ' '); 00156 if (space == NULL) { 00157 tprintf("Bad format in tr file, reading fontname, unichar\n"); 00158 continue; 00159 } 00160 *space++ = '\0'; 00161 int font_id = GetFontInfoId(buffer); 00162 if (font_id < 0) font_id = 0; 00163 int page_number; 00164 STRING unichar; 00165 TBOX bounding_box; 00166 if (!ParseBoxFileStr(space, &page_number, &unichar, &bounding_box)) { 00167 tprintf("Bad format in tr file, reading box coords\n"); 00168 continue; 00169 } 00170 CHAR_DESC char_desc = ReadCharDescription(feature_defs, fp); 00171 TrainingSample* sample = new TrainingSample; 00172 sample->set_font_id(font_id); 00173 sample->set_page_num(page_number + page_images_.size()); 00174 sample->set_bounding_box(bounding_box); 00175 sample->ExtractCharDesc(int_feature_type, micro_feature_type, 00176 cn_feature_type, geo_feature_type, char_desc); 00177 AddSample(verification, unichar.string(), sample); 00178 FreeCharDescription(char_desc); 00179 } 00180 charsetsize_ = unicharset_.size(); 00181 fclose(fp); 00182 } 00183 00184 // Adds the given single sample to the trainer, setting the classid 00185 // appropriately from the given unichar_str. 00186 void MasterTrainer::AddSample(bool verification, const char* unichar, 00187 TrainingSample* sample) { 00188 if (verification) { 00189 verify_samples_.AddSample(unichar, sample); 00190 prev_unichar_id_ = -1; 00191 } else if (unicharset_.contains_unichar(unichar)) { 00192 if (prev_unichar_id_ >= 0) 00193 fragments_[prev_unichar_id_] = -1; 00194 prev_unichar_id_ = samples_.AddSample(unichar, sample); 00195 if (flat_shapes_.FindShape(prev_unichar_id_, sample->font_id()) < 0) 00196 flat_shapes_.AddShape(prev_unichar_id_, sample->font_id()); 00197 } else { 00198 int junk_id = junk_samples_.AddSample(unichar, sample); 00199 if (prev_unichar_id_ >= 0) { 00200 CHAR_FRAGMENT* frag = CHAR_FRAGMENT::parse_from_string(unichar); 00201 if (frag != NULL && frag->is_natural()) { 00202 if (fragments_[prev_unichar_id_] == 0) 00203 fragments_[prev_unichar_id_] = junk_id; 00204 else if (fragments_[prev_unichar_id_] != junk_id) 00205 fragments_[prev_unichar_id_] = -1; 00206 } 00207 delete frag; 00208 } 00209 prev_unichar_id_ = -1; 00210 } 00211 } 00212 00213 // Loads all pages from the given tif filename and append to page_images_. 00214 // Must be called after ReadTrainingSamples, as the current number of images 00215 // is used as an offset for page numbers in the samples. 00216 void MasterTrainer::LoadPageImages(const char* filename) { 00217 int page; 00218 Pix* pix; 00219 for (page = 0; (pix = pixReadTiff(filename, page)) != NULL; ++page) { 00220 page_images_.push_back(pix); 00221 } 00222 tprintf("Loaded %d page images from %s\n", page, filename); 00223 } 00224 00225 // Cleans up the samples after initial load from the tr files, and prior to 00226 // saving the MasterTrainer: 00227 // Remaps fragmented chars if running shape anaylsis. 00228 // Sets up the samples appropriately for class/fontwise access. 00229 // Deletes outlier samples. 00230 void MasterTrainer::PostLoadCleanup() { 00231 if (debug_level_ > 0) 00232 tprintf("PostLoadCleanup...\n"); 00233 if (enable_shape_anaylsis_) 00234 ReplaceFragmentedSamples(); 00235 SampleIterator sample_it; 00236 sample_it.Init(NULL, NULL, true, &verify_samples_); 00237 sample_it.NormalizeSamples(); 00238 verify_samples_.OrganizeByFontAndClass(); 00239 00240 samples_.IndexFeatures(feature_space_); 00241 // TODO(rays) DeleteOutliers is currently turned off to prove NOP-ness 00242 // against current training. 00243 // samples_.DeleteOutliers(feature_space_, debug_level_ > 0); 00244 samples_.OrganizeByFontAndClass(); 00245 if (debug_level_ > 0) 00246 tprintf("ComputeCanonicalSamples...\n"); 00247 samples_.ComputeCanonicalSamples(feature_map_, debug_level_ > 0); 00248 } 00249 00250 // Gets the samples ready for training. Use after both 00251 // ReadTrainingSamples+PostLoadCleanup or DeSerialize. 00252 // Re-indexes the features and computes canonical and cloud features. 00253 void MasterTrainer::PreTrainingSetup() { 00254 if (debug_level_ > 0) 00255 tprintf("PreTrainingSetup...\n"); 00256 samples_.IndexFeatures(feature_space_); 00257 samples_.ComputeCanonicalFeatures(); 00258 if (debug_level_ > 0) 00259 tprintf("ComputeCloudFeatures...\n"); 00260 samples_.ComputeCloudFeatures(feature_space_.Size()); 00261 } 00262 00263 // Sets up the master_shapes_ table, which tells which fonts should stay 00264 // together until they get to a leaf node classifier. 00265 void MasterTrainer::SetupMasterShapes() { 00266 tprintf("Building master shape table\n"); 00267 int num_fonts = samples_.NumFonts(); 00268 00269 ShapeTable char_shapes_begin_fragment(samples_.unicharset()); 00270 ShapeTable char_shapes_end_fragment(samples_.unicharset()); 00271 ShapeTable char_shapes(samples_.unicharset()); 00272 for (int c = 0; c < samples_.charsetsize(); ++c) { 00273 ShapeTable shapes(samples_.unicharset()); 00274 for (int f = 0; f < num_fonts; ++f) { 00275 if (samples_.NumClassSamples(f, c, true) > 0) 00276 shapes.AddShape(c, f); 00277 } 00278 ClusterShapes(kMinClusteredShapes, 1, kFontMergeDistance, &shapes); 00279 00280 const CHAR_FRAGMENT *fragment = samples_.unicharset().get_fragment(c); 00281 00282 if (fragment == NULL) 00283 char_shapes.AppendMasterShapes(shapes, NULL); 00284 else if (fragment->is_beginning()) 00285 char_shapes_begin_fragment.AppendMasterShapes(shapes, NULL); 00286 else if (fragment->is_ending()) 00287 char_shapes_end_fragment.AppendMasterShapes(shapes, NULL); 00288 else 00289 char_shapes.AppendMasterShapes(shapes, NULL); 00290 } 00291 ClusterShapes(kMinClusteredShapes, kMaxUnicharsPerCluster, 00292 kFontMergeDistance, &char_shapes_begin_fragment); 00293 char_shapes.AppendMasterShapes(char_shapes_begin_fragment, NULL); 00294 ClusterShapes(kMinClusteredShapes, kMaxUnicharsPerCluster, 00295 kFontMergeDistance, &char_shapes_end_fragment); 00296 char_shapes.AppendMasterShapes(char_shapes_end_fragment, NULL); 00297 ClusterShapes(kMinClusteredShapes, kMaxUnicharsPerCluster, 00298 kFontMergeDistance, &char_shapes); 00299 master_shapes_.AppendMasterShapes(char_shapes, NULL); 00300 tprintf("Master shape_table:%s\n", master_shapes_.SummaryStr().string()); 00301 } 00302 00303 // Adds the junk_samples_ to the main samples_ set. Junk samples are initially 00304 // fragments and n-grams (all incorrectly segmented characters). 00305 // Various training functions may result in incorrectly segmented characters 00306 // being added to the unicharset of the main samples, perhaps because they 00307 // form a "radical" decomposition of some (Indic) grapheme, or because they 00308 // just look the same as a real character (like rn/m) 00309 // This function moves all the junk samples, to the main samples_ set, but 00310 // desirable junk, being any sample for which the unichar already exists in 00311 // the samples_ unicharset gets the unichar-ids re-indexed to match, but 00312 // anything else gets re-marked as unichar_id 0 (space character) to identify 00313 // it as junk to the error counter. 00314 void MasterTrainer::IncludeJunk() { 00315 // Get ids of fragments in junk_samples_ that replace the dead chars. 00316 const UNICHARSET& junk_set = junk_samples_.unicharset(); 00317 const UNICHARSET& sample_set = samples_.unicharset(); 00318 int num_junks = junk_samples_.num_samples(); 00319 tprintf("Moving %d junk samples to master sample set.\n", num_junks); 00320 for (int s = 0; s < num_junks; ++s) { 00321 TrainingSample* sample = junk_samples_.mutable_sample(s); 00322 int junk_id = sample->class_id(); 00323 const char* junk_utf8 = junk_set.id_to_unichar(junk_id); 00324 int sample_id = sample_set.unichar_to_id(junk_utf8); 00325 if (sample_id == INVALID_UNICHAR_ID) 00326 sample_id = 0; 00327 sample->set_class_id(sample_id); 00328 junk_samples_.extract_sample(s); 00329 samples_.AddSample(sample_id, sample); 00330 } 00331 junk_samples_.DeleteDeadSamples(); 00332 samples_.OrganizeByFontAndClass(); 00333 } 00334 00335 // Replicates the samples and perturbs them if the enable_replication_ flag 00336 // is set. MUST be used after the last call to OrganizeByFontAndClass on 00337 // the training samples, ie after IncludeJunk if it is going to be used, as 00338 // OrganizeByFontAndClass will eat the replicated samples into the regular 00339 // samples. 00340 void MasterTrainer::ReplicateAndRandomizeSamplesIfRequired() { 00341 if (enable_replication_) { 00342 if (debug_level_ > 0) 00343 tprintf("ReplicateAndRandomize...\n"); 00344 verify_samples_.ReplicateAndRandomizeSamples(); 00345 samples_.ReplicateAndRandomizeSamples(); 00346 samples_.IndexFeatures(feature_space_); 00347 } 00348 } 00349 00350 // Loads the basic font properties file into fontinfo_table_. 00351 // Returns false on failure. 00352 bool MasterTrainer::LoadFontInfo(const char* filename) { 00353 FILE* fp = fopen(filename, "rb"); 00354 if (fp == NULL) { 00355 fprintf(stderr, "Failed to load font_properties from %s\n", filename); 00356 return false; 00357 } 00358 int italic, bold, fixed, serif, fraktur; 00359 while (!feof(fp)) { 00360 FontInfo fontinfo; 00361 char* font_name = new char[1024]; 00362 fontinfo.name = font_name; 00363 fontinfo.properties = 0; 00364 fontinfo.universal_id = 0; 00365 if (tfscanf(fp, "%1024s %i %i %i %i %i\n", font_name, 00366 &italic, &bold, &fixed, &serif, &fraktur) != 6) 00367 continue; 00368 fontinfo.properties = 00369 (italic << 0) + 00370 (bold << 1) + 00371 (fixed << 2) + 00372 (serif << 3) + 00373 (fraktur << 4); 00374 if (!fontinfo_table_.contains(fontinfo)) { 00375 fontinfo_table_.push_back(fontinfo); 00376 } 00377 } 00378 fclose(fp); 00379 return true; 00380 } 00381 00382 // Loads the xheight font properties file into xheights_. 00383 // Returns false on failure. 00384 bool MasterTrainer::LoadXHeights(const char* filename) { 00385 tprintf("fontinfo table is of size %d\n", fontinfo_table_.size()); 00386 xheights_.init_to_size(fontinfo_table_.size(), -1); 00387 if (filename == NULL) return true; 00388 FILE *f = fopen(filename, "rb"); 00389 if (f == NULL) { 00390 fprintf(stderr, "Failed to load font xheights from %s\n", filename); 00391 return false; 00392 } 00393 tprintf("Reading x-heights from %s ...\n", filename); 00394 FontInfo fontinfo; 00395 fontinfo.properties = 0; // Not used to lookup in the table. 00396 fontinfo.universal_id = 0; 00397 char buffer[1024]; 00398 int xht; 00399 int total_xheight = 0; 00400 int xheight_count = 0; 00401 while (!feof(f)) { 00402 if (tfscanf(f, "%1023s %d\n", buffer, &xht) != 2) 00403 continue; 00404 buffer[1023] = '\0'; 00405 fontinfo.name = buffer; 00406 if (!fontinfo_table_.contains(fontinfo)) continue; 00407 int fontinfo_id = fontinfo_table_.get_index(fontinfo); 00408 xheights_[fontinfo_id] = xht; 00409 total_xheight += xht; 00410 ++xheight_count; 00411 } 00412 if (xheight_count == 0) { 00413 fprintf(stderr, "No valid xheights in %s!\n", filename); 00414 fclose(f); 00415 return false; 00416 } 00417 int mean_xheight = DivRounded(total_xheight, xheight_count); 00418 for (int i = 0; i < fontinfo_table_.size(); ++i) { 00419 if (xheights_[i] < 0) 00420 xheights_[i] = mean_xheight; 00421 } 00422 fclose(f); 00423 return true; 00424 } // LoadXHeights 00425 00426 // Reads spacing stats from filename and adds them to fontinfo_table. 00427 bool MasterTrainer::AddSpacingInfo(const char *filename) { 00428 FILE* fontinfo_file = fopen(filename, "rb"); 00429 if (fontinfo_file == NULL) 00430 return true; // We silently ignore missing files! 00431 // Find the fontinfo_id. 00432 int fontinfo_id = GetBestMatchingFontInfoId(filename); 00433 if (fontinfo_id < 0) { 00434 tprintf("No font found matching fontinfo filename %s\n", filename); 00435 fclose(fontinfo_file); 00436 return false; 00437 } 00438 tprintf("Reading spacing from %s for font %d...\n", filename, fontinfo_id); 00439 // TODO(rays) scale should probably be a double, but keep as an int for now 00440 // to duplicate current behavior. 00441 int scale = kBlnXHeight / xheights_[fontinfo_id]; 00442 int num_unichars; 00443 char uch[UNICHAR_LEN]; 00444 char kerned_uch[UNICHAR_LEN]; 00445 int x_gap, x_gap_before, x_gap_after, num_kerned; 00446 ASSERT_HOST(tfscanf(fontinfo_file, "%d\n", &num_unichars) == 1); 00447 FontInfo *fi = &fontinfo_table_.get(fontinfo_id); 00448 fi->init_spacing(unicharset_.size()); 00449 FontSpacingInfo *spacing = NULL; 00450 for (int l = 0; l < num_unichars; ++l) { 00451 if (tfscanf(fontinfo_file, "%s %d %d %d", 00452 uch, &x_gap_before, &x_gap_after, &num_kerned) != 4) { 00453 tprintf("Bad format of font spacing file %s\n", filename); 00454 fclose(fontinfo_file); 00455 return false; 00456 } 00457 bool valid = unicharset_.contains_unichar(uch); 00458 if (valid) { 00459 spacing = new FontSpacingInfo(); 00460 spacing->x_gap_before = static_cast<inT16>(x_gap_before * scale); 00461 spacing->x_gap_after = static_cast<inT16>(x_gap_after * scale); 00462 } 00463 for (int k = 0; k < num_kerned; ++k) { 00464 if (tfscanf(fontinfo_file, "%s %d", kerned_uch, &x_gap) != 2) { 00465 tprintf("Bad format of font spacing file %s\n", filename); 00466 fclose(fontinfo_file); 00467 delete spacing; 00468 return false; 00469 } 00470 if (!valid || !unicharset_.contains_unichar(kerned_uch)) continue; 00471 spacing->kerned_unichar_ids.push_back( 00472 unicharset_.unichar_to_id(kerned_uch)); 00473 spacing->kerned_x_gaps.push_back(static_cast<inT16>(x_gap * scale)); 00474 } 00475 if (valid) fi->add_spacing(unicharset_.unichar_to_id(uch), spacing); 00476 } 00477 fclose(fontinfo_file); 00478 return true; 00479 } 00480 00481 // Returns the font id corresponding to the given font name. 00482 // Returns -1 if the font cannot be found. 00483 int MasterTrainer::GetFontInfoId(const char* font_name) { 00484 FontInfo fontinfo; 00485 // We are only borrowing the string, so it is OK to const cast it. 00486 fontinfo.name = const_cast<char*>(font_name); 00487 fontinfo.properties = 0; // Not used to lookup in the table 00488 fontinfo.universal_id = 0; 00489 return fontinfo_table_.get_index(fontinfo); 00490 } 00491 // Returns the font_id of the closest matching font name to the given 00492 // filename. It is assumed that a substring of the filename will match 00493 // one of the fonts. If more than one is matched, the longest is returned. 00494 int MasterTrainer::GetBestMatchingFontInfoId(const char* filename) { 00495 int fontinfo_id = -1; 00496 int best_len = 0; 00497 for (int f = 0; f < fontinfo_table_.size(); ++f) { 00498 if (strstr(filename, fontinfo_table_.get(f).name) != NULL) { 00499 int len = strlen(fontinfo_table_.get(f).name); 00500 // Use the longest matching length in case a substring of a font matched. 00501 if (len > best_len) { 00502 best_len = len; 00503 fontinfo_id = f; 00504 } 00505 } 00506 } 00507 return fontinfo_id; 00508 } 00509 00510 // Sets up a flat shapetable with one shape per class/font combination. 00511 void MasterTrainer::SetupFlatShapeTable(ShapeTable* shape_table) { 00512 // To exactly mimic the results of the previous implementation, the shapes 00513 // must be clustered in order the fonts arrived, and reverse order of the 00514 // characters within each font. 00515 // Get a list of the fonts in the order they appeared. 00516 GenericVector<int> active_fonts; 00517 int num_shapes = flat_shapes_.NumShapes(); 00518 for (int s = 0; s < num_shapes; ++s) { 00519 int font = flat_shapes_.GetShape(s)[0].font_ids[0]; 00520 int f = 0; 00521 for (f = 0; f < active_fonts.size(); ++f) { 00522 if (active_fonts[f] == font) 00523 break; 00524 } 00525 if (f == active_fonts.size()) 00526 active_fonts.push_back(font); 00527 } 00528 // For each font in order, add all the shapes with that font in reverse order. 00529 int num_fonts = active_fonts.size(); 00530 for (int f = 0; f < num_fonts; ++f) { 00531 for (int s = num_shapes - 1; s >= 0; --s) { 00532 int font = flat_shapes_.GetShape(s)[0].font_ids[0]; 00533 if (font == active_fonts[f]) { 00534 shape_table->AddShape(flat_shapes_.GetShape(s)); 00535 } 00536 } 00537 } 00538 } 00539 00540 // Sets up a Clusterer for mftraining on a single shape_id. 00541 // Call FreeClusterer on the return value after use. 00542 CLUSTERER* MasterTrainer::SetupForClustering( 00543 const ShapeTable& shape_table, 00544 const FEATURE_DEFS_STRUCT& feature_defs, 00545 int shape_id, 00546 int* num_samples) { 00547 00548 int desc_index = ShortNameToFeatureType(feature_defs, kMicroFeatureType); 00549 int num_params = feature_defs.FeatureDesc[desc_index]->NumParams; 00550 ASSERT_HOST(num_params == MFCount); 00551 CLUSTERER* clusterer = MakeClusterer( 00552 num_params, feature_defs.FeatureDesc[desc_index]->ParamDesc); 00553 00554 // We want to iterate over the samples of just the one shape. 00555 IndexMapBiDi shape_map; 00556 shape_map.Init(shape_table.NumShapes(), false); 00557 shape_map.SetMap(shape_id, true); 00558 shape_map.Setup(); 00559 // Reverse the order of the samples to match the previous behavior. 00560 GenericVector<const TrainingSample*> sample_ptrs; 00561 SampleIterator it; 00562 it.Init(&shape_map, &shape_table, false, &samples_); 00563 for (it.Begin(); !it.AtEnd(); it.Next()) { 00564 sample_ptrs.push_back(&it.GetSample()); 00565 } 00566 int sample_id = 0; 00567 for (int i = sample_ptrs.size() - 1; i >= 0; --i) { 00568 const TrainingSample* sample = sample_ptrs[i]; 00569 int num_features = sample->num_micro_features(); 00570 for (int f = 0; f < num_features; ++f) 00571 MakeSample(clusterer, sample->micro_features()[f], sample_id); 00572 ++sample_id; 00573 } 00574 *num_samples = sample_id; 00575 return clusterer; 00576 } 00577 00578 // Writes the given float_classes (produced by SetupForFloat2Int) as inttemp 00579 // to the given inttemp_file, and the corresponding pffmtable. 00580 // The unicharset is the original encoding of graphemes, and shape_set should 00581 // match the size of the shape_table, and may possibly be totally fake. 00582 void MasterTrainer::WriteInttempAndPFFMTable(const UNICHARSET& unicharset, 00583 const UNICHARSET& shape_set, 00584 const ShapeTable& shape_table, 00585 CLASS_STRUCT* float_classes, 00586 const char* inttemp_file, 00587 const char* pffmtable_file) { 00588 tesseract::Classify *classify = new tesseract::Classify(); 00589 // Move the fontinfo table to classify. 00590 fontinfo_table_.MoveTo(&classify->get_fontinfo_table()); 00591 INT_TEMPLATES int_templates = classify->CreateIntTemplates(float_classes, 00592 shape_set); 00593 FILE* fp = fopen(inttemp_file, "wb"); 00594 classify->WriteIntTemplates(fp, int_templates, shape_set); 00595 fclose(fp); 00596 // Now write pffmtable. This is complicated by the fact that the adaptive 00597 // classifier still wants one indexed by unichar-id, but the static 00598 // classifier needs one indexed by its shape class id. 00599 // We put the shapetable_cutoffs in a GenericVector, and compute the 00600 // unicharset cutoffs along the way. 00601 GenericVector<uinT16> shapetable_cutoffs; 00602 GenericVector<uinT16> unichar_cutoffs; 00603 for (int c = 0; c < unicharset.size(); ++c) 00604 unichar_cutoffs.push_back(0); 00605 /* then write out each class */ 00606 for (int i = 0; i < int_templates->NumClasses; ++i) { 00607 INT_CLASS Class = ClassForClassId(int_templates, i); 00608 // Todo: Test with min instead of max 00609 // int MaxLength = LengthForConfigId(Class, 0); 00610 uinT16 max_length = 0; 00611 for (int config_id = 0; config_id < Class->NumConfigs; config_id++) { 00612 // Todo: Test with min instead of max 00613 // if (LengthForConfigId (Class, config_id) < MaxLength) 00614 uinT16 length = Class->ConfigLengths[config_id]; 00615 if (length > max_length) 00616 max_length = Class->ConfigLengths[config_id]; 00617 int shape_id = float_classes[i].font_set.get(config_id); 00618 const Shape& shape = shape_table.GetShape(shape_id); 00619 for (int c = 0; c < shape.size(); ++c) { 00620 int unichar_id = shape[c].unichar_id; 00621 if (length > unichar_cutoffs[unichar_id]) 00622 unichar_cutoffs[unichar_id] = length; 00623 } 00624 } 00625 shapetable_cutoffs.push_back(max_length); 00626 } 00627 fp = fopen(pffmtable_file, "wb"); 00628 shapetable_cutoffs.Serialize(fp); 00629 for (int c = 0; c < unicharset.size(); ++c) { 00630 const char *unichar = unicharset.id_to_unichar(c); 00631 if (strcmp(unichar, " ") == 0) { 00632 unichar = "NULL"; 00633 } 00634 fprintf(fp, "%s %d\n", unichar, unichar_cutoffs[c]); 00635 } 00636 fclose(fp); 00637 free_int_templates(int_templates); 00638 delete classify; 00639 } 00640 00641 // Generate debug output relating to the canonical distance between the 00642 // two given UTF8 grapheme strings. 00643 void MasterTrainer::DebugCanonical(const char* unichar_str1, 00644 const char* unichar_str2) { 00645 int class_id1 = unicharset_.unichar_to_id(unichar_str1); 00646 int class_id2 = unicharset_.unichar_to_id(unichar_str2); 00647 if (class_id2 == INVALID_UNICHAR_ID) 00648 class_id2 = class_id1; 00649 if (class_id1 == INVALID_UNICHAR_ID) { 00650 tprintf("No unicharset entry found for %s\n", unichar_str1); 00651 return; 00652 } else { 00653 tprintf("Font ambiguities for unichar %d = %s and %d = %s\n", 00654 class_id1, unichar_str1, class_id2, unichar_str2); 00655 } 00656 int num_fonts = samples_.NumFonts(); 00657 const IntFeatureMap& feature_map = feature_map_; 00658 // Iterate the fonts to get the similarity with other fonst of the same 00659 // class. 00660 tprintf(" "); 00661 for (int f = 0; f < num_fonts; ++f) { 00662 if (samples_.NumClassSamples(f, class_id2, false) == 0) 00663 continue; 00664 tprintf("%6d", f); 00665 } 00666 tprintf("\n"); 00667 for (int f1 = 0; f1 < num_fonts; ++f1) { 00668 // Map the features of the canonical_sample. 00669 if (samples_.NumClassSamples(f1, class_id1, false) == 0) 00670 continue; 00671 tprintf("%4d ", f1); 00672 for (int f2 = 0; f2 < num_fonts; ++f2) { 00673 if (samples_.NumClassSamples(f2, class_id2, false) == 0) 00674 continue; 00675 float dist = samples_.ClusterDistance(f1, class_id1, f2, class_id2, 00676 feature_map); 00677 tprintf(" %5.3f", dist); 00678 } 00679 tprintf("\n"); 00680 } 00681 // Build a fake ShapeTable containing all the sample types. 00682 ShapeTable shapes(unicharset_); 00683 for (int f = 0; f < num_fonts; ++f) { 00684 if (samples_.NumClassSamples(f, class_id1, true) > 0) 00685 shapes.AddShape(class_id1, f); 00686 if (class_id1 != class_id2 && 00687 samples_.NumClassSamples(f, class_id2, true) > 0) 00688 shapes.AddShape(class_id2, f); 00689 } 00690 } 00691 00692 #ifndef GRAPHICS_DISABLED 00693 // Debugging for cloud/canonical features. 00694 // Displays a Features window containing: 00695 // If unichar_str2 is in the unicharset, and canonical_font is non-negative, 00696 // displays the canonical features of the char/font combination in red. 00697 // If unichar_str1 is in the unicharset, and cloud_font is non-negative, 00698 // displays the cloud feature of the char/font combination in green. 00699 // The canonical features are drawn first to show which ones have no 00700 // matches in the cloud features. 00701 // Until the features window is destroyed, each click in the features window 00702 // will display the samples that have that feature in a separate window. 00703 void MasterTrainer::DisplaySamples(const char* unichar_str1, int cloud_font, 00704 const char* unichar_str2, 00705 int canonical_font) { 00706 const IntFeatureMap& feature_map = feature_map_; 00707 const IntFeatureSpace& feature_space = feature_map.feature_space(); 00708 ScrollView* f_window = CreateFeatureSpaceWindow("Features", 100, 500); 00709 ClearFeatureSpaceWindow(norm_mode_ == NM_BASELINE ? baseline : character, 00710 f_window); 00711 int class_id2 = samples_.unicharset().unichar_to_id(unichar_str2); 00712 if (class_id2 != INVALID_UNICHAR_ID && canonical_font >= 0) { 00713 const TrainingSample* sample = samples_.GetCanonicalSample(canonical_font, 00714 class_id2); 00715 for (int f = 0; f < sample->num_features(); ++f) { 00716 RenderIntFeature(f_window, &sample->features()[f], ScrollView::RED); 00717 } 00718 } 00719 int class_id1 = samples_.unicharset().unichar_to_id(unichar_str1); 00720 if (class_id1 != INVALID_UNICHAR_ID && cloud_font >= 0) { 00721 const BitVector& cloud = samples_.GetCloudFeatures(cloud_font, class_id1); 00722 for (int f = 0; f < cloud.size(); ++f) { 00723 if (cloud[f]) { 00724 INT_FEATURE_STRUCT feature = 00725 feature_map.InverseIndexFeature(f); 00726 RenderIntFeature(f_window, &feature, ScrollView::GREEN); 00727 } 00728 } 00729 } 00730 f_window->Update(); 00731 ScrollView* s_window = CreateFeatureSpaceWindow("Samples", 100, 500); 00732 SVEventType ev_type; 00733 do { 00734 SVEvent* ev; 00735 // Wait until a click or popup event. 00736 ev = f_window->AwaitEvent(SVET_ANY); 00737 ev_type = ev->type; 00738 if (ev_type == SVET_CLICK) { 00739 int feature_index = feature_space.XYToFeatureIndex(ev->x, ev->y); 00740 if (feature_index >= 0) { 00741 // Iterate samples and display those with the feature. 00742 Shape shape; 00743 shape.AddToShape(class_id1, cloud_font); 00744 s_window->Clear(); 00745 samples_.DisplaySamplesWithFeature(feature_index, shape, 00746 feature_space, ScrollView::GREEN, 00747 s_window); 00748 s_window->Update(); 00749 } 00750 } 00751 delete ev; 00752 } while (ev_type != SVET_DESTROY); 00753 } 00754 #endif // GRAPHICS_DISABLED 00755 00756 void MasterTrainer::TestClassifierVOld(bool replicate_samples, 00757 ShapeClassifier* test_classifier, 00758 ShapeClassifier* old_classifier) { 00759 SampleIterator sample_it; 00760 sample_it.Init(NULL, NULL, replicate_samples, &samples_); 00761 ErrorCounter::DebugNewErrors(test_classifier, old_classifier, 00762 CT_UNICHAR_TOPN_ERR, fontinfo_table_, 00763 page_images_, &sample_it); 00764 } 00765 00766 // Tests the given test_classifier on the internal samples. 00767 // See TestClassifier for details. 00768 void MasterTrainer::TestClassifierOnSamples(CountTypes error_mode, 00769 int report_level, 00770 bool replicate_samples, 00771 ShapeClassifier* test_classifier, 00772 STRING* report_string) { 00773 TestClassifier(error_mode, report_level, replicate_samples, &samples_, 00774 test_classifier, report_string); 00775 } 00776 00777 // Tests the given test_classifier on the given samples. 00778 // error_mode indicates what counts as an error. 00779 // report_levels: 00780 // 0 = no output. 00781 // 1 = bottom-line error rate. 00782 // 2 = bottom-line error rate + time. 00783 // 3 = font-level error rate + time. 00784 // 4 = list of all errors + short classifier debug output on 16 errors. 00785 // 5 = list of all errors + short classifier debug output on 25 errors. 00786 // If replicate_samples is true, then the test is run on an extended test 00787 // sample including replicated and systematically perturbed samples. 00788 // If report_string is non-NULL, a summary of the results for each font 00789 // is appended to the report_string. 00790 double MasterTrainer::TestClassifier(CountTypes error_mode, 00791 int report_level, 00792 bool replicate_samples, 00793 TrainingSampleSet* samples, 00794 ShapeClassifier* test_classifier, 00795 STRING* report_string) { 00796 SampleIterator sample_it; 00797 sample_it.Init(NULL, NULL, replicate_samples, samples); 00798 if (report_level > 0) { 00799 int num_samples = 0; 00800 for (sample_it.Begin(); !sample_it.AtEnd(); sample_it.Next()) 00801 ++num_samples; 00802 tprintf("Iterator has charset size of %d/%d, %d shapes, %d samples\n", 00803 sample_it.SparseCharsetSize(), sample_it.CompactCharsetSize(), 00804 test_classifier->GetShapeTable()->NumShapes(), num_samples); 00805 tprintf("Testing %sREPLICATED:\n", replicate_samples ? "" : "NON-"); 00806 } 00807 double unichar_error = 0.0; 00808 ErrorCounter::ComputeErrorRate(test_classifier, report_level, 00809 error_mode, fontinfo_table_, 00810 page_images_, &sample_it, &unichar_error, 00811 NULL, report_string); 00812 return unichar_error; 00813 } 00814 00815 // Returns the average (in some sense) distance between the two given 00816 // shapes, which may contain multiple fonts and/or unichars. 00817 float MasterTrainer::ShapeDistance(const ShapeTable& shapes, int s1, int s2) { 00818 const IntFeatureMap& feature_map = feature_map_; 00819 const Shape& shape1 = shapes.GetShape(s1); 00820 const Shape& shape2 = shapes.GetShape(s2); 00821 int num_chars1 = shape1.size(); 00822 int num_chars2 = shape2.size(); 00823 float dist_sum = 0.0f; 00824 int dist_count = 0; 00825 if (num_chars1 > 1 || num_chars2 > 1) { 00826 // In the multi-char case try to optimize the calculation by computing 00827 // distances between characters of matching font where possible. 00828 for (int c1 = 0; c1 < num_chars1; ++c1) { 00829 for (int c2 = 0; c2 < num_chars2; ++c2) { 00830 dist_sum += samples_.UnicharDistance(shape1[c1], shape2[c2], 00831 true, feature_map); 00832 ++dist_count; 00833 } 00834 } 00835 } else { 00836 // In the single unichar case, there is little alternative, but to compute 00837 // the squared-order distance between pairs of fonts. 00838 dist_sum = samples_.UnicharDistance(shape1[0], shape2[0], 00839 false, feature_map); 00840 ++dist_count; 00841 } 00842 return dist_sum / dist_count; 00843 } 00844 00845 // Replaces samples that are always fragmented with the corresponding 00846 // fragment samples. 00847 void MasterTrainer::ReplaceFragmentedSamples() { 00848 if (fragments_ == NULL) return; 00849 // Remove samples that are replaced by fragments. Each class that was 00850 // always naturally fragmented should be replaced by its fragments. 00851 int num_samples = samples_.num_samples(); 00852 for (int s = 0; s < num_samples; ++s) { 00853 TrainingSample* sample = samples_.mutable_sample(s); 00854 if (fragments_[sample->class_id()] > 0) 00855 samples_.KillSample(sample); 00856 } 00857 samples_.DeleteDeadSamples(); 00858 00859 // Get ids of fragments in junk_samples_ that replace the dead chars. 00860 const UNICHARSET& frag_set = junk_samples_.unicharset(); 00861 #if 0 00862 // TODO(rays) The original idea was to replace only graphemes that were 00863 // always naturally fragmented, but that left a lot of the Indic graphemes 00864 // out. Determine whether we can go back to that idea now that spacing 00865 // is fixed in the training images, or whether this code is obsolete. 00866 bool* good_junk = new bool[frag_set.size()]; 00867 memset(good_junk, 0, sizeof(*good_junk) * frag_set.size()); 00868 for (int dead_ch = 1; dead_ch < unicharset_.size(); ++dead_ch) { 00869 int frag_ch = fragments_[dead_ch]; 00870 if (frag_ch <= 0) continue; 00871 const char* frag_utf8 = frag_set.id_to_unichar(frag_ch); 00872 CHAR_FRAGMENT* frag = CHAR_FRAGMENT::parse_from_string(frag_utf8); 00873 // Mark the chars for all parts of the fragment as good in good_junk. 00874 for (int part = 0; part < frag->get_total(); ++part) { 00875 frag->set_pos(part); 00876 int good_ch = frag_set.unichar_to_id(frag->to_string().string()); 00877 if (good_ch != INVALID_UNICHAR_ID) 00878 good_junk[good_ch] = true; // We want this one. 00879 } 00880 } 00881 #endif 00882 // For now just use all the junk that was from natural fragments. 00883 // Get samples of fragments in junk_samples_ that replace the dead chars. 00884 int num_junks = junk_samples_.num_samples(); 00885 for (int s = 0; s < num_junks; ++s) { 00886 TrainingSample* sample = junk_samples_.mutable_sample(s); 00887 int junk_id = sample->class_id(); 00888 const char* frag_utf8 = frag_set.id_to_unichar(junk_id); 00889 CHAR_FRAGMENT* frag = CHAR_FRAGMENT::parse_from_string(frag_utf8); 00890 if (frag != NULL && frag->is_natural()) { 00891 junk_samples_.extract_sample(s); 00892 samples_.AddSample(frag_set.id_to_unichar(junk_id), sample); 00893 } 00894 } 00895 junk_samples_.DeleteDeadSamples(); 00896 junk_samples_.OrganizeByFontAndClass(); 00897 samples_.OrganizeByFontAndClass(); 00898 unicharset_.clear(); 00899 unicharset_.AppendOtherUnicharset(samples_.unicharset()); 00900 // delete [] good_junk; 00901 // Fragments_ no longer needed? 00902 delete [] fragments_; 00903 fragments_ = NULL; 00904 } 00905 00906 // Runs a hierarchical agglomerative clustering to merge shapes in the given 00907 // shape_table, while satisfying the given constraints: 00908 // * End with at least min_shapes left in shape_table, 00909 // * No shape shall have more than max_shape_unichars in it, 00910 // * Don't merge shapes where the distance between them exceeds max_dist. 00911 const float kInfiniteDist = 999.0f; 00912 void MasterTrainer::ClusterShapes(int min_shapes, int max_shape_unichars, 00913 float max_dist, ShapeTable* shapes) { 00914 int num_shapes = shapes->NumShapes(); 00915 int max_merges = num_shapes - min_shapes; 00916 GenericVector<ShapeDist>* shape_dists = 00917 new GenericVector<ShapeDist>[num_shapes]; 00918 float min_dist = kInfiniteDist; 00919 int min_s1 = 0; 00920 int min_s2 = 0; 00921 tprintf("Computing shape distances..."); 00922 for (int s1 = 0; s1 < num_shapes; ++s1) { 00923 for (int s2 = s1 + 1; s2 < num_shapes; ++s2) { 00924 ShapeDist dist(s1, s2, ShapeDistance(*shapes, s1, s2)); 00925 shape_dists[s1].push_back(dist); 00926 if (dist.distance < min_dist) { 00927 min_dist = dist.distance; 00928 min_s1 = s1; 00929 min_s2 = s2; 00930 } 00931 } 00932 tprintf(" %d", s1); 00933 } 00934 tprintf("\n"); 00935 int num_merged = 0; 00936 while (num_merged < max_merges && min_dist < max_dist) { 00937 tprintf("Distance = %f: ", min_dist); 00938 int num_unichars = shapes->MergedUnicharCount(min_s1, min_s2); 00939 shape_dists[min_s1][min_s2 - min_s1 - 1].distance = kInfiniteDist; 00940 if (num_unichars > max_shape_unichars) { 00941 tprintf("Merge of %d and %d with %d would exceed max of %d unichars\n", 00942 min_s1, min_s2, num_unichars, max_shape_unichars); 00943 } else { 00944 shapes->MergeShapes(min_s1, min_s2); 00945 shape_dists[min_s2].clear(); 00946 ++num_merged; 00947 00948 for (int s = 0; s < min_s1; ++s) { 00949 if (!shape_dists[s].empty()) { 00950 shape_dists[s][min_s1 - s - 1].distance = 00951 ShapeDistance(*shapes, s, min_s1); 00952 shape_dists[s][min_s2 - s -1].distance = kInfiniteDist; 00953 } 00954 } 00955 for (int s2 = min_s1 + 1; s2 < num_shapes; ++s2) { 00956 if (shape_dists[min_s1][s2 - min_s1 - 1].distance < kInfiniteDist) 00957 shape_dists[min_s1][s2 - min_s1 - 1].distance = 00958 ShapeDistance(*shapes, min_s1, s2); 00959 } 00960 for (int s = min_s1 + 1; s < min_s2; ++s) { 00961 if (!shape_dists[s].empty()) { 00962 shape_dists[s][min_s2 - s - 1].distance = kInfiniteDist; 00963 } 00964 } 00965 } 00966 min_dist = kInfiniteDist; 00967 for (int s1 = 0; s1 < num_shapes; ++s1) { 00968 for (int i = 0; i < shape_dists[s1].size(); ++i) { 00969 if (shape_dists[s1][i].distance < min_dist) { 00970 min_dist = shape_dists[s1][i].distance; 00971 min_s1 = s1; 00972 min_s2 = s1 + 1 + i; 00973 } 00974 } 00975 } 00976 } 00977 tprintf("Stopped with %d merged, min dist %f\n", num_merged, min_dist); 00978 delete [] shape_dists; 00979 if (debug_level_ > 1) { 00980 for (int s1 = 0; s1 < num_shapes; ++s1) { 00981 if (shapes->MasterDestinationIndex(s1) == s1) { 00982 tprintf("Master shape:%s\n", shapes->DebugStr(s1).string()); 00983 } 00984 } 00985 } 00986 } 00987 00988 00989 } // namespace tesseract.