tesseract 3.04.01

classify/mastertrainer.cpp

Go to the documentation of this file.
00001 // Copyright 2010 Google Inc. All Rights Reserved.
00002 // Author: rays@google.com (Ray Smith)
00004 // File:        mastertrainer.cpp
00005 // Description: Trainer to build the MasterClassifier.
00006 // Author:      Ray Smith
00007 // Created:     Wed Nov 03 18:10:01 PDT 2010
00008 //
00009 // (C) Copyright 2010, Google Inc.
00010 // Licensed under the Apache License, Version 2.0 (the "License");
00011 // you may not use this file except in compliance with the License.
00012 // You may obtain a copy of the License at
00013 // http://www.apache.org/licenses/LICENSE-2.0
00014 // Unless required by applicable law or agreed to in writing, software
00015 // distributed under the License is distributed on an "AS IS" BASIS,
00016 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00017 // See the License for the specific language governing permissions and
00018 // limitations under the License.
00019 //
00021 
00022 // Include automatically generated configuration file if running autoconf.
00023 #ifdef HAVE_CONFIG_H
00024 #include "config_auto.h"
00025 #endif
00026 
00027 #include "mastertrainer.h"
00028 #include <math.h>
00029 #include <time.h>
00030 #include "allheaders.h"
00031 #include "boxread.h"
00032 #include "classify.h"
00033 #include "efio.h"
00034 #include "errorcounter.h"
00035 #include "featdefs.h"
00036 #include "sampleiterator.h"
00037 #include "shapeclassifier.h"
00038 #include "shapetable.h"
00039 #include "svmnode.h"
00040 
00041 #include "scanutils.h"
00042 
00043 namespace tesseract {
00044 
// Clustering limits. When kMinClusteredShapes is low and
// kMaxUnicharsPerCluster is high, kFontMergeDistance becomes the only
// constraint that actually limits merging.

// Lower bound on the number of shapes the clustering may produce.
const int kMinClusteredShapes = 1;
// Upper bound on the number of unichars merged into any single cluster.
const int kMaxUnicharsPerCluster = 2000;
// Fonts and unichars whose mean font distance is below this value are merged.
const float kFontMergeDistance = 0.025;
00053 
00054 MasterTrainer::MasterTrainer(NormalizationMode norm_mode,
00055                              bool shape_analysis,
00056                              bool replicate_samples,
00057                              int debug_level)
00058   : norm_mode_(norm_mode), samples_(fontinfo_table_),
00059     junk_samples_(fontinfo_table_), verify_samples_(fontinfo_table_),
00060     charsetsize_(0),
00061     enable_shape_anaylsis_(shape_analysis),
00062     enable_replication_(replicate_samples),
00063     fragments_(NULL), prev_unichar_id_(-1), debug_level_(debug_level) {
00064 }
00065 
00066 MasterTrainer::~MasterTrainer() {
00067   delete [] fragments_;
00068   for (int p = 0; p < page_images_.size(); ++p)
00069     pixDestroy(&page_images_[p]);
00070 }
00071 
00072 // WARNING! Serialize/DeSerialize are only partial, providing
00073 // enough data to get the samples back and display them.
00074 // Writes to the given file. Returns false in case of error.
00075 bool MasterTrainer::Serialize(FILE* fp) const {
00076   if (fwrite(&norm_mode_, sizeof(norm_mode_), 1, fp) != 1) return false;
00077   if (!unicharset_.save_to_file(fp)) return false;
00078   if (!feature_space_.Serialize(fp)) return false;
00079   if (!samples_.Serialize(fp)) return false;
00080   if (!junk_samples_.Serialize(fp)) return false;
00081   if (!verify_samples_.Serialize(fp)) return false;
00082   if (!master_shapes_.Serialize(fp)) return false;
00083   if (!flat_shapes_.Serialize(fp)) return false;
00084   if (!fontinfo_table_.Serialize(fp)) return false;
00085   if (!xheights_.Serialize(fp)) return false;
00086   return true;
00087 }
00088 
00089 // Reads from the given file. Returns false in case of error.
00090 // If swap is true, assumes a big/little-endian swap is needed.
00091 bool MasterTrainer::DeSerialize(bool swap, FILE* fp) {
00092   if (fread(&norm_mode_, sizeof(norm_mode_), 1, fp) != 1) return false;
00093   if (swap) {
00094     ReverseN(&norm_mode_, sizeof(norm_mode_));
00095   }
00096   if (!unicharset_.load_from_file(fp)) return false;
00097   charsetsize_ = unicharset_.size();
00098   if (!feature_space_.DeSerialize(swap, fp)) return false;
00099   feature_map_.Init(feature_space_);
00100   if (!samples_.DeSerialize(swap, fp)) return false;
00101   if (!junk_samples_.DeSerialize(swap, fp)) return false;
00102   if (!verify_samples_.DeSerialize(swap, fp)) return false;
00103   if (!master_shapes_.DeSerialize(swap, fp)) return false;
00104   if (!flat_shapes_.DeSerialize(swap, fp)) return false;
00105   if (!fontinfo_table_.DeSerialize(swap, fp)) return false;
00106   if (!xheights_.DeSerialize(swap, fp)) return false;
00107   return true;
00108 }
00109 
00110 // Load an initial unicharset, or set one up if the file cannot be read.
00111 void MasterTrainer::LoadUnicharset(const char* filename) {
00112   if (!unicharset_.load_from_file(filename)) {
00113     tprintf("Failed to load unicharset from file %s\n"
00114             "Building unicharset for training from scratch...\n",
00115             filename);
00116     unicharset_.clear();
00117     UNICHARSET initialized;
00118     // Add special characters, as they were removed by the clear, but the
00119     // default constructor puts them in.
00120     unicharset_.AppendOtherUnicharset(initialized);
00121   }
00122   charsetsize_ = unicharset_.size();
00123   delete [] fragments_;
00124   fragments_ = new int[charsetsize_];
00125   memset(fragments_, 0, sizeof(*fragments_) * charsetsize_);
00126   samples_.LoadUnicharset(filename);
00127   junk_samples_.LoadUnicharset(filename);
00128   verify_samples_.LoadUnicharset(filename);
00129 }
00130 
00131 // Reads the samples and their features from the given .tr format file,
00132 // adding them to the trainer with the font_id from the content of the file.
00133 // See mftraining.cpp for a description of the file format.
00134 // If verification, then these are verification samples, not training.
00135 void MasterTrainer::ReadTrainingSamples(const char* page_name,
00136                                         const FEATURE_DEFS_STRUCT& feature_defs,
00137                                         bool verification) {
00138   char buffer[2048];
00139   int int_feature_type = ShortNameToFeatureType(feature_defs, kIntFeatureType);
00140   int micro_feature_type = ShortNameToFeatureType(feature_defs,
00141                                                   kMicroFeatureType);
00142   int cn_feature_type = ShortNameToFeatureType(feature_defs, kCNFeatureType);
00143   int geo_feature_type = ShortNameToFeatureType(feature_defs, kGeoFeatureType);
00144 
00145   FILE* fp = Efopen(page_name, "rb");
00146   if (fp == NULL) {
00147     tprintf("Failed to open tr file: %s\n", page_name);
00148     return;
00149   }
00150   tr_filenames_.push_back(STRING(page_name));
00151   while (fgets(buffer, sizeof(buffer), fp) != NULL) {
00152     if (buffer[0] == '\n')
00153       continue;
00154 
00155     char* space = strchr(buffer, ' ');
00156     if (space == NULL) {
00157       tprintf("Bad format in tr file, reading fontname, unichar\n");
00158       continue;
00159     }
00160     *space++ = '\0';
00161     int font_id = GetFontInfoId(buffer);
00162     if (font_id < 0) font_id = 0;
00163     int page_number;
00164     STRING unichar;
00165     TBOX bounding_box;
00166     if (!ParseBoxFileStr(space, &page_number, &unichar, &bounding_box)) {
00167       tprintf("Bad format in tr file, reading box coords\n");
00168       continue;
00169     }
00170     CHAR_DESC char_desc = ReadCharDescription(feature_defs, fp);
00171     TrainingSample* sample = new TrainingSample;
00172     sample->set_font_id(font_id);
00173     sample->set_page_num(page_number + page_images_.size());
00174     sample->set_bounding_box(bounding_box);
00175     sample->ExtractCharDesc(int_feature_type, micro_feature_type,
00176                             cn_feature_type, geo_feature_type, char_desc);
00177     AddSample(verification, unichar.string(), sample);
00178     FreeCharDescription(char_desc);
00179   }
00180   charsetsize_ = unicharset_.size();
00181   fclose(fp);
00182 }
00183 
00184 // Adds the given single sample to the trainer, setting the classid
00185 // appropriately from the given unichar_str.
00186 void MasterTrainer::AddSample(bool verification, const char* unichar,
00187                               TrainingSample* sample) {
00188   if (verification) {
00189     verify_samples_.AddSample(unichar, sample);
00190     prev_unichar_id_ = -1;
00191   } else if (unicharset_.contains_unichar(unichar)) {
00192     if (prev_unichar_id_ >= 0)
00193       fragments_[prev_unichar_id_] = -1;
00194     prev_unichar_id_ = samples_.AddSample(unichar, sample);
00195     if (flat_shapes_.FindShape(prev_unichar_id_, sample->font_id()) < 0)
00196       flat_shapes_.AddShape(prev_unichar_id_, sample->font_id());
00197   } else {
00198     int junk_id = junk_samples_.AddSample(unichar, sample);
00199     if (prev_unichar_id_ >= 0) {
00200       CHAR_FRAGMENT* frag = CHAR_FRAGMENT::parse_from_string(unichar);
00201       if (frag != NULL && frag->is_natural()) {
00202         if (fragments_[prev_unichar_id_] == 0)
00203           fragments_[prev_unichar_id_] = junk_id;
00204         else if (fragments_[prev_unichar_id_] != junk_id)
00205           fragments_[prev_unichar_id_] = -1;
00206       }
00207       delete frag;
00208     }
00209     prev_unichar_id_ = -1;
00210   }
00211 }
00212 
00213 // Loads all pages from the given tif filename and append to page_images_.
00214 // Must be called after ReadTrainingSamples, as the current number of images
00215 // is used as an offset for page numbers in the samples.
00216 void MasterTrainer::LoadPageImages(const char* filename) {
00217   int page;
00218   Pix* pix;
00219   for (page = 0; (pix = pixReadTiff(filename, page)) != NULL; ++page) {
00220     page_images_.push_back(pix);
00221   }
00222   tprintf("Loaded %d page images from %s\n", page, filename);
00223 }
00224 
00225 // Cleans up the samples after initial load from the tr files, and prior to
00226 // saving the MasterTrainer:
00227 // Remaps fragmented chars if running shape analysis.
00228 // Sets up the samples appropriately for class/fontwise access.
00229 // Deletes outlier samples.
00230 void MasterTrainer::PostLoadCleanup() {
00231   if (debug_level_ > 0)
00232     tprintf("PostLoadCleanup...\n");
00233   if (enable_shape_anaylsis_)
00234     ReplaceFragmentedSamples();
00235   SampleIterator sample_it;
00236   sample_it.Init(NULL, NULL, true, &verify_samples_);
00237   sample_it.NormalizeSamples();
00238   verify_samples_.OrganizeByFontAndClass();
00239 
00240   samples_.IndexFeatures(feature_space_);
00241   // TODO(rays) DeleteOutliers is currently turned off to prove NOP-ness
00242   // against current training.
00243   //  samples_.DeleteOutliers(feature_space_, debug_level_ > 0);
00244   samples_.OrganizeByFontAndClass();
00245   if (debug_level_ > 0)
00246     tprintf("ComputeCanonicalSamples...\n");
00247   samples_.ComputeCanonicalSamples(feature_map_, debug_level_ > 0);
00248 }
00249 
00250 // Gets the samples ready for training. Use after both
00251 // ReadTrainingSamples+PostLoadCleanup or DeSerialize.
00252 // Re-indexes the features and computes canonical and cloud features.
00253 void MasterTrainer::PreTrainingSetup() {
00254   if (debug_level_ > 0)
00255     tprintf("PreTrainingSetup...\n");
00256   samples_.IndexFeatures(feature_space_);
00257   samples_.ComputeCanonicalFeatures();
00258   if (debug_level_ > 0)
00259     tprintf("ComputeCloudFeatures...\n");
00260   samples_.ComputeCloudFeatures(feature_space_.Size());
00261 }
00262 
00263 // Sets up the master_shapes_ table, which tells which fonts should stay
00264 // together until they get to a leaf node classifier.
00265 void MasterTrainer::SetupMasterShapes() {
00266   tprintf("Building master shape table\n");
00267   int num_fonts = samples_.NumFonts();
00268 
00269   ShapeTable char_shapes_begin_fragment(samples_.unicharset());
00270   ShapeTable char_shapes_end_fragment(samples_.unicharset());
00271   ShapeTable char_shapes(samples_.unicharset());
00272   for (int c = 0; c < samples_.charsetsize(); ++c) {
00273     ShapeTable shapes(samples_.unicharset());
00274     for (int f = 0; f < num_fonts; ++f) {
00275       if (samples_.NumClassSamples(f, c, true) > 0)
00276         shapes.AddShape(c, f);
00277     }
00278     ClusterShapes(kMinClusteredShapes, 1, kFontMergeDistance, &shapes);
00279 
00280     const CHAR_FRAGMENT *fragment = samples_.unicharset().get_fragment(c);
00281 
00282     if (fragment == NULL)
00283       char_shapes.AppendMasterShapes(shapes, NULL);
00284     else if (fragment->is_beginning())
00285       char_shapes_begin_fragment.AppendMasterShapes(shapes, NULL);
00286     else if (fragment->is_ending())
00287       char_shapes_end_fragment.AppendMasterShapes(shapes, NULL);
00288     else
00289       char_shapes.AppendMasterShapes(shapes, NULL);
00290   }
00291   ClusterShapes(kMinClusteredShapes, kMaxUnicharsPerCluster,
00292                 kFontMergeDistance, &char_shapes_begin_fragment);
00293   char_shapes.AppendMasterShapes(char_shapes_begin_fragment, NULL);
00294   ClusterShapes(kMinClusteredShapes, kMaxUnicharsPerCluster,
00295                 kFontMergeDistance, &char_shapes_end_fragment);
00296   char_shapes.AppendMasterShapes(char_shapes_end_fragment, NULL);
00297   ClusterShapes(kMinClusteredShapes, kMaxUnicharsPerCluster,
00298                 kFontMergeDistance, &char_shapes);
00299   master_shapes_.AppendMasterShapes(char_shapes, NULL);
00300   tprintf("Master shape_table:%s\n", master_shapes_.SummaryStr().string());
00301 }
00302 
00303 // Adds the junk_samples_ to the main samples_ set. Junk samples are initially
00304 // fragments and n-grams (all incorrectly segmented characters).
00305 // Various training functions may result in incorrectly segmented characters
00306 // being added to the unicharset of the main samples, perhaps because they
00307 // form a "radical" decomposition of some (Indic) grapheme, or because they
00308 // just look the same as a real character (like rn/m)
00309 // This function moves all the junk samples to the main samples_ set, but
00310 // desirable junk, being any sample for which the unichar already exists in
00311 // the samples_ unicharset gets the unichar-ids re-indexed to match, but
00312 // anything else gets re-marked as unichar_id 0 (space character) to identify
00313 // it as junk to the error counter.
00314 void MasterTrainer::IncludeJunk() {
00315   // Get ids of fragments in junk_samples_ that replace the dead chars.
00316   const UNICHARSET& junk_set = junk_samples_.unicharset();
00317   const UNICHARSET& sample_set = samples_.unicharset();
00318   int num_junks = junk_samples_.num_samples();
00319   tprintf("Moving %d junk samples to master sample set.\n", num_junks);
00320   for (int s = 0; s < num_junks; ++s) {
00321     TrainingSample* sample = junk_samples_.mutable_sample(s);
00322     int junk_id = sample->class_id();
00323     const char* junk_utf8 = junk_set.id_to_unichar(junk_id);
00324     int sample_id = sample_set.unichar_to_id(junk_utf8);
00325     if (sample_id == INVALID_UNICHAR_ID)
00326       sample_id = 0;
00327     sample->set_class_id(sample_id);
00328     junk_samples_.extract_sample(s);
00329     samples_.AddSample(sample_id, sample);
00330   }
00331   junk_samples_.DeleteDeadSamples();
00332   samples_.OrganizeByFontAndClass();
00333 }
00334 
00335 // Replicates the samples and perturbs them if the enable_replication_ flag
00336 // is set. MUST be used after the last call to OrganizeByFontAndClass on
00337 // the training samples, i.e. after IncludeJunk if it is going to be used, as
00338 // OrganizeByFontAndClass will eat the replicated samples into the regular
00339 // samples.
00340 void MasterTrainer::ReplicateAndRandomizeSamplesIfRequired() {
00341   if (enable_replication_) {
00342     if (debug_level_ > 0)
00343       tprintf("ReplicateAndRandomize...\n");
00344     verify_samples_.ReplicateAndRandomizeSamples();
00345     samples_.ReplicateAndRandomizeSamples();
00346     samples_.IndexFeatures(feature_space_);
00347   }
00348 }
00349 
00350 // Loads the basic font properties file into fontinfo_table_.
00351 // Returns false on failure.
00352 bool MasterTrainer::LoadFontInfo(const char* filename) {
00353   FILE* fp = fopen(filename, "rb");
00354   if (fp == NULL) {
00355     fprintf(stderr, "Failed to load font_properties from %s\n", filename);
00356     return false;
00357   }
00358   int italic, bold, fixed, serif, fraktur;
00359   while (!feof(fp)) {
00360     FontInfo fontinfo;
00361     char* font_name = new char[1024];
00362     fontinfo.name = font_name;
00363     fontinfo.properties = 0;
00364     fontinfo.universal_id = 0;
00365     if (tfscanf(fp, "%1024s %i %i %i %i %i\n", font_name,
00366                 &italic, &bold, &fixed, &serif, &fraktur) != 6)
00367       continue;
00368     fontinfo.properties =
00369         (italic << 0) +
00370         (bold << 1) +
00371         (fixed << 2) +
00372         (serif << 3) +
00373         (fraktur << 4);
00374     if (!fontinfo_table_.contains(fontinfo)) {
00375       fontinfo_table_.push_back(fontinfo);
00376     }
00377   }
00378   fclose(fp);
00379   return true;
00380 }
00381 
00382 // Loads the xheight font properties file into xheights_.
00383 // Returns false on failure.
00384 bool MasterTrainer::LoadXHeights(const char* filename) {
00385   tprintf("fontinfo table is of size %d\n", fontinfo_table_.size());
00386   xheights_.init_to_size(fontinfo_table_.size(), -1);
00387   if (filename == NULL) return true;
00388   FILE *f = fopen(filename, "rb");
00389   if (f == NULL) {
00390     fprintf(stderr, "Failed to load font xheights from %s\n", filename);
00391     return false;
00392   }
00393   tprintf("Reading x-heights from %s ...\n", filename);
00394   FontInfo fontinfo;
00395   fontinfo.properties = 0;  // Not used to lookup in the table.
00396   fontinfo.universal_id = 0;
00397   char buffer[1024];
00398   int xht;
00399   int total_xheight = 0;
00400   int xheight_count = 0;
00401   while (!feof(f)) {
00402     if (tfscanf(f, "%1023s %d\n", buffer, &xht) != 2)
00403       continue;
00404     buffer[1023] = '\0';
00405     fontinfo.name = buffer;
00406     if (!fontinfo_table_.contains(fontinfo)) continue;
00407     int fontinfo_id = fontinfo_table_.get_index(fontinfo);
00408     xheights_[fontinfo_id] = xht;
00409     total_xheight += xht;
00410     ++xheight_count;
00411   }
00412   if (xheight_count == 0) {
00413     fprintf(stderr, "No valid xheights in %s!\n", filename);
00414     fclose(f);
00415     return false;
00416   }
00417   int mean_xheight = DivRounded(total_xheight, xheight_count);
00418   for (int i = 0; i < fontinfo_table_.size(); ++i) {
00419     if (xheights_[i] < 0)
00420       xheights_[i] = mean_xheight;
00421   }
00422   fclose(f);
00423   return true;
00424 }  // LoadXHeights
00425 
00426 // Reads spacing stats from filename and adds them to fontinfo_table.
00427 bool MasterTrainer::AddSpacingInfo(const char *filename) {
00428   FILE* fontinfo_file = fopen(filename, "rb");
00429   if (fontinfo_file == NULL)
00430     return true;  // We silently ignore missing files!
00431   // Find the fontinfo_id.
00432   int fontinfo_id = GetBestMatchingFontInfoId(filename);
00433   if (fontinfo_id < 0) {
00434     tprintf("No font found matching fontinfo filename %s\n", filename);
00435     fclose(fontinfo_file);
00436     return false;
00437   }
00438   tprintf("Reading spacing from %s for font %d...\n", filename, fontinfo_id);
00439   // TODO(rays) scale should probably be a double, but keep as an int for now
00440   // to duplicate current behavior.
00441   int scale = kBlnXHeight / xheights_[fontinfo_id];
00442   int num_unichars;
00443   char uch[UNICHAR_LEN];
00444   char kerned_uch[UNICHAR_LEN];
00445   int x_gap, x_gap_before, x_gap_after, num_kerned;
00446   ASSERT_HOST(tfscanf(fontinfo_file, "%d\n", &num_unichars) == 1);
00447   FontInfo *fi = &fontinfo_table_.get(fontinfo_id);
00448   fi->init_spacing(unicharset_.size());
00449   FontSpacingInfo *spacing = NULL;
00450   for (int l = 0; l < num_unichars; ++l) {
00451     if (tfscanf(fontinfo_file, "%s %d %d %d",
00452                 uch, &x_gap_before, &x_gap_after, &num_kerned) != 4) {
00453       tprintf("Bad format of font spacing file %s\n", filename);
00454       fclose(fontinfo_file);
00455       return false;
00456     }
00457     bool valid = unicharset_.contains_unichar(uch);
00458     if (valid) {
00459       spacing = new FontSpacingInfo();
00460       spacing->x_gap_before = static_cast<inT16>(x_gap_before * scale);
00461       spacing->x_gap_after = static_cast<inT16>(x_gap_after * scale);
00462     }
00463     for (int k = 0; k < num_kerned; ++k) {
00464       if (tfscanf(fontinfo_file, "%s %d", kerned_uch, &x_gap) != 2) {
00465         tprintf("Bad format of font spacing file %s\n", filename);
00466         fclose(fontinfo_file);
00467         delete spacing;
00468         return false;
00469       }
00470       if (!valid || !unicharset_.contains_unichar(kerned_uch)) continue;
00471       spacing->kerned_unichar_ids.push_back(
00472           unicharset_.unichar_to_id(kerned_uch));
00473       spacing->kerned_x_gaps.push_back(static_cast<inT16>(x_gap * scale));
00474     }
00475     if (valid) fi->add_spacing(unicharset_.unichar_to_id(uch), spacing);
00476   }
00477   fclose(fontinfo_file);
00478   return true;
00479 }
00480 
00481 // Returns the font id corresponding to the given font name.
00482 // Returns -1 if the font cannot be found.
00483 int MasterTrainer::GetFontInfoId(const char* font_name) {
00484   FontInfo fontinfo;
00485   // We are only borrowing the string, so it is OK to const cast it.
00486   fontinfo.name = const_cast<char*>(font_name);
00487   fontinfo.properties = 0;  // Not used to lookup in the table
00488   fontinfo.universal_id = 0;
00489   return fontinfo_table_.get_index(fontinfo);
00490 }
00491 // Returns the font_id of the closest matching font name to the given
00492 // filename. It is assumed that a substring of the filename will match
00493 // one of the fonts. If more than one is matched, the longest is returned.
00494 int MasterTrainer::GetBestMatchingFontInfoId(const char* filename) {
00495   int fontinfo_id = -1;
00496   int best_len = 0;
00497   for (int f = 0; f < fontinfo_table_.size(); ++f) {
00498     if (strstr(filename, fontinfo_table_.get(f).name) != NULL) {
00499       int len = strlen(fontinfo_table_.get(f).name);
00500       // Use the longest matching length in case a substring of a font matched.
00501       if (len > best_len) {
00502         best_len = len;
00503         fontinfo_id = f;
00504       }
00505     }
00506   }
00507   return fontinfo_id;
00508 }
00509 
00510 // Sets up a flat shapetable with one shape per class/font combination.
00511 void MasterTrainer::SetupFlatShapeTable(ShapeTable* shape_table) {
00512   // To exactly mimic the results of the previous implementation, the shapes
00513   // must be clustered in order the fonts arrived, and reverse order of the
00514   // characters within each font.
00515   // Get a list of the fonts in the order they appeared.
00516   GenericVector<int> active_fonts;
00517   int num_shapes = flat_shapes_.NumShapes();
00518   for (int s = 0; s < num_shapes; ++s) {
00519     int font = flat_shapes_.GetShape(s)[0].font_ids[0];
00520     int f = 0;
00521     for (f = 0; f < active_fonts.size(); ++f) {
00522       if (active_fonts[f] == font)
00523         break;
00524     }
00525     if (f == active_fonts.size())
00526       active_fonts.push_back(font);
00527   }
00528   // For each font in order, add all the shapes with that font in reverse order.
00529   int num_fonts = active_fonts.size();
00530   for (int f = 0; f < num_fonts; ++f) {
00531     for (int s = num_shapes - 1; s >= 0; --s) {
00532       int font = flat_shapes_.GetShape(s)[0].font_ids[0];
00533       if (font == active_fonts[f]) {
00534         shape_table->AddShape(flat_shapes_.GetShape(s));
00535       }
00536     }
00537   }
00538 }
00539 
00540 // Sets up a Clusterer for mftraining on a single shape_id.
00541 // Call FreeClusterer on the return value after use.
00542 CLUSTERER* MasterTrainer::SetupForClustering(
00543     const ShapeTable& shape_table,
00544     const FEATURE_DEFS_STRUCT& feature_defs,
00545     int shape_id,
00546     int* num_samples) {
00547 
00548   int desc_index = ShortNameToFeatureType(feature_defs, kMicroFeatureType);
00549   int num_params = feature_defs.FeatureDesc[desc_index]->NumParams;
00550   ASSERT_HOST(num_params == MFCount);
00551   CLUSTERER* clusterer = MakeClusterer(
00552       num_params, feature_defs.FeatureDesc[desc_index]->ParamDesc);
00553 
00554   // We want to iterate over the samples of just the one shape.
00555   IndexMapBiDi shape_map;
00556   shape_map.Init(shape_table.NumShapes(), false);
00557   shape_map.SetMap(shape_id, true);
00558   shape_map.Setup();
00559   // Reverse the order of the samples to match the previous behavior.
00560   GenericVector<const TrainingSample*> sample_ptrs;
00561   SampleIterator it;
00562   it.Init(&shape_map, &shape_table, false, &samples_);
00563   for (it.Begin(); !it.AtEnd(); it.Next()) {
00564     sample_ptrs.push_back(&it.GetSample());
00565   }
00566   int sample_id = 0;
00567   for (int i = sample_ptrs.size() - 1; i >= 0; --i) {
00568     const TrainingSample* sample = sample_ptrs[i];
00569     int num_features = sample->num_micro_features();
00570     for (int f = 0; f < num_features; ++f)
00571       MakeSample(clusterer, sample->micro_features()[f], sample_id);
00572     ++sample_id;
00573   }
00574   *num_samples = sample_id;
00575   return clusterer;
00576 }
00577 
00578 // Writes the given float_classes (produced by SetupForFloat2Int) as inttemp
00579 // to the given inttemp_file, and the corresponding pffmtable.
00580 // The unicharset is the original encoding of graphemes, and shape_set should
00581 // match the size of the shape_table, and may possibly be totally fake.
00582 void MasterTrainer::WriteInttempAndPFFMTable(const UNICHARSET& unicharset,
00583                                              const UNICHARSET& shape_set,
00584                                              const ShapeTable& shape_table,
00585                                              CLASS_STRUCT* float_classes,
00586                                              const char* inttemp_file,
00587                                              const char* pffmtable_file) {
00588   tesseract::Classify *classify = new tesseract::Classify();
00589   // Move the fontinfo table to classify.
00590   fontinfo_table_.MoveTo(&classify->get_fontinfo_table());
00591   INT_TEMPLATES int_templates = classify->CreateIntTemplates(float_classes,
00592                                                              shape_set);
00593   FILE* fp = fopen(inttemp_file, "wb");
00594   classify->WriteIntTemplates(fp, int_templates, shape_set);
00595   fclose(fp);
00596   // Now write pffmtable. This is complicated by the fact that the adaptive
00597   // classifier still wants one indexed by unichar-id, but the static
00598   // classifier needs one indexed by its shape class id.
00599   // We put the shapetable_cutoffs in a GenericVector, and compute the
00600   // unicharset cutoffs along the way.
00601   GenericVector<uinT16> shapetable_cutoffs;
00602   GenericVector<uinT16> unichar_cutoffs;
00603   for (int c = 0; c < unicharset.size(); ++c)
00604     unichar_cutoffs.push_back(0);
00605   /* then write out each class */
00606   for (int i = 0; i < int_templates->NumClasses; ++i) {
00607     INT_CLASS Class = ClassForClassId(int_templates, i);
00608     // Todo: Test with min instead of max
00609     // int MaxLength = LengthForConfigId(Class, 0);
00610     uinT16 max_length = 0;
00611     for (int config_id = 0; config_id < Class->NumConfigs; config_id++) {
00612       // Todo: Test with min instead of max
00613       // if (LengthForConfigId (Class, config_id) < MaxLength)
00614       uinT16 length = Class->ConfigLengths[config_id];
00615       if (length > max_length)
00616         max_length = Class->ConfigLengths[config_id];
00617       int shape_id = float_classes[i].font_set.get(config_id);
00618       const Shape& shape = shape_table.GetShape(shape_id);
00619       for (int c = 0; c < shape.size(); ++c) {
00620         int unichar_id = shape[c].unichar_id;
00621         if (length > unichar_cutoffs[unichar_id])
00622           unichar_cutoffs[unichar_id] = length;
00623       }
00624     }
00625     shapetable_cutoffs.push_back(max_length);
00626   }
00627   fp = fopen(pffmtable_file, "wb");
00628   shapetable_cutoffs.Serialize(fp);
00629   for (int c = 0; c < unicharset.size(); ++c) {
00630     const char *unichar = unicharset.id_to_unichar(c);
00631     if (strcmp(unichar, " ") == 0) {
00632       unichar = "NULL";
00633     }
00634     fprintf(fp, "%s %d\n", unichar, unichar_cutoffs[c]);
00635   }
00636   fclose(fp);
00637   free_int_templates(int_templates);
00638   delete classify;
00639 }
00640 
00641 // Generate debug output relating to the canonical distance between the
00642 // two given UTF8 grapheme strings.
00643 void MasterTrainer::DebugCanonical(const char* unichar_str1,
00644                                    const char* unichar_str2) {
00645   int class_id1 = unicharset_.unichar_to_id(unichar_str1);
00646   int class_id2 = unicharset_.unichar_to_id(unichar_str2);
00647   if (class_id2 == INVALID_UNICHAR_ID)
00648     class_id2 = class_id1;
00649   if (class_id1 == INVALID_UNICHAR_ID) {
00650     tprintf("No unicharset entry found for %s\n", unichar_str1);
00651     return;
00652   } else {
00653     tprintf("Font ambiguities for unichar %d = %s and %d = %s\n",
00654             class_id1, unichar_str1, class_id2, unichar_str2);
00655   }
00656   int num_fonts = samples_.NumFonts();
00657   const IntFeatureMap& feature_map = feature_map_;
00658   // Iterate the fonts to get the similarity with other fonst of the same
00659   // class.
00660   tprintf("      ");
00661   for (int f = 0; f < num_fonts; ++f) {
00662     if (samples_.NumClassSamples(f, class_id2, false) == 0)
00663       continue;
00664     tprintf("%6d", f);
00665   }
00666   tprintf("\n");
00667   for (int f1 = 0; f1 < num_fonts; ++f1) {
00668     // Map the features of the canonical_sample.
00669     if (samples_.NumClassSamples(f1, class_id1, false) == 0)
00670       continue;
00671     tprintf("%4d  ", f1);
00672     for (int f2 = 0; f2 < num_fonts; ++f2) {
00673       if (samples_.NumClassSamples(f2, class_id2, false) == 0)
00674         continue;
00675       float dist = samples_.ClusterDistance(f1, class_id1, f2, class_id2,
00676                                             feature_map);
00677       tprintf(" %5.3f", dist);
00678     }
00679     tprintf("\n");
00680   }
00681   // Build a fake ShapeTable containing all the sample types.
00682   ShapeTable shapes(unicharset_);
00683   for (int f = 0; f < num_fonts; ++f) {
00684     if (samples_.NumClassSamples(f, class_id1, true) > 0)
00685       shapes.AddShape(class_id1, f);
00686     if (class_id1 != class_id2 &&
00687         samples_.NumClassSamples(f, class_id2, true) > 0)
00688       shapes.AddShape(class_id2, f);
00689   }
00690 }
00691 
00692 #ifndef GRAPHICS_DISABLED
00693 // Debugging for cloud/canonical features.
00694 // Displays a Features window containing:
00695 // If unichar_str2 is in the unicharset, and canonical_font is non-negative,
00696 // displays the canonical features of the char/font combination in red.
00697 // If unichar_str1 is in the unicharset, and cloud_font is non-negative,
00698 // displays the cloud feature of the char/font combination in green.
00699 // The canonical features are drawn first to show which ones have no
00700 // matches in the cloud features.
00701 // Until the features window is destroyed, each click in the features window
00702 // will display the samples that have that feature in a separate window.
00703 void MasterTrainer::DisplaySamples(const char* unichar_str1, int cloud_font,
00704                                    const char* unichar_str2,
00705                                    int canonical_font) {
00706   const IntFeatureMap& feature_map = feature_map_;
00707   const IntFeatureSpace& feature_space = feature_map.feature_space();
00708   ScrollView* f_window = CreateFeatureSpaceWindow("Features", 100, 500);
00709   ClearFeatureSpaceWindow(norm_mode_ == NM_BASELINE ? baseline : character,
00710                           f_window);
00711   int class_id2 = samples_.unicharset().unichar_to_id(unichar_str2);
00712   if (class_id2 != INVALID_UNICHAR_ID && canonical_font >= 0) {
00713     const TrainingSample* sample = samples_.GetCanonicalSample(canonical_font,
00714                                                                class_id2);
00715     for (int f = 0; f < sample->num_features(); ++f) {
00716       RenderIntFeature(f_window, &sample->features()[f], ScrollView::RED);
00717     }
00718   }
00719   int class_id1 = samples_.unicharset().unichar_to_id(unichar_str1);
00720   if (class_id1 != INVALID_UNICHAR_ID && cloud_font >= 0) {
00721     const BitVector& cloud = samples_.GetCloudFeatures(cloud_font, class_id1);
00722     for (int f = 0; f < cloud.size(); ++f) {
00723       if (cloud[f]) {
00724         INT_FEATURE_STRUCT feature =
00725             feature_map.InverseIndexFeature(f);
00726         RenderIntFeature(f_window, &feature, ScrollView::GREEN);
00727       }
00728     }
00729   }
00730   f_window->Update();
00731   ScrollView* s_window = CreateFeatureSpaceWindow("Samples", 100, 500);
00732   SVEventType ev_type;
00733   do {
00734     SVEvent* ev;
00735     // Wait until a click or popup event.
00736     ev = f_window->AwaitEvent(SVET_ANY);
00737     ev_type = ev->type;
00738     if (ev_type == SVET_CLICK) {
00739       int feature_index = feature_space.XYToFeatureIndex(ev->x, ev->y);
00740       if (feature_index >= 0) {
00741         // Iterate samples and display those with the feature.
00742         Shape shape;
00743         shape.AddToShape(class_id1, cloud_font);
00744         s_window->Clear();
00745         samples_.DisplaySamplesWithFeature(feature_index, shape,
00746                                            feature_space, ScrollView::GREEN,
00747                                            s_window);
00748         s_window->Update();
00749       }
00750     }
00751     delete ev;
00752   } while (ev_type != SVET_DESTROY);
00753 }
00754 #endif  // GRAPHICS_DISABLED
00755 
00756 void MasterTrainer::TestClassifierVOld(bool replicate_samples,
00757                                        ShapeClassifier* test_classifier,
00758                                        ShapeClassifier* old_classifier) {
00759   SampleIterator sample_it;
00760   sample_it.Init(NULL, NULL, replicate_samples, &samples_);
00761   ErrorCounter::DebugNewErrors(test_classifier, old_classifier,
00762                                CT_UNICHAR_TOPN_ERR, fontinfo_table_,
00763                                page_images_, &sample_it);
00764 }
00765 
00766 // Tests the given test_classifier on the internal samples.
00767 // See TestClassifier for details.
00768 void MasterTrainer::TestClassifierOnSamples(CountTypes error_mode,
00769                                             int report_level,
00770                                             bool replicate_samples,
00771                                             ShapeClassifier* test_classifier,
00772                                             STRING* report_string) {
00773   TestClassifier(error_mode, report_level, replicate_samples, &samples_,
00774                  test_classifier, report_string);
00775 }
00776 
00777 // Tests the given test_classifier on the given samples.
00778 // error_mode indicates what counts as an error.
00779 // report_levels:
00780 // 0 = no output.
00781 // 1 = bottom-line error rate.
00782 // 2 = bottom-line error rate + time.
00783 // 3 = font-level error rate + time.
00784 // 4 = list of all errors + short classifier debug output on 16 errors.
00785 // 5 = list of all errors + short classifier debug output on 25 errors.
00786 // If replicate_samples is true, then the test is run on an extended test
00787 // sample including replicated and systematically perturbed samples.
00788 // If report_string is non-NULL, a summary of the results for each font
00789 // is appended to the report_string.
00790 double MasterTrainer::TestClassifier(CountTypes error_mode,
00791                                      int report_level,
00792                                      bool replicate_samples,
00793                                      TrainingSampleSet* samples,
00794                                      ShapeClassifier* test_classifier,
00795                                      STRING* report_string) {
00796   SampleIterator sample_it;
00797   sample_it.Init(NULL, NULL, replicate_samples, samples);
00798   if (report_level > 0) {
00799     int num_samples = 0;
00800     for (sample_it.Begin(); !sample_it.AtEnd(); sample_it.Next())
00801       ++num_samples;
00802     tprintf("Iterator has charset size of %d/%d, %d shapes, %d samples\n",
00803             sample_it.SparseCharsetSize(), sample_it.CompactCharsetSize(),
00804             test_classifier->GetShapeTable()->NumShapes(), num_samples);
00805     tprintf("Testing %sREPLICATED:\n", replicate_samples ? "" : "NON-");
00806   }
00807   double unichar_error = 0.0;
00808   ErrorCounter::ComputeErrorRate(test_classifier, report_level,
00809                                  error_mode, fontinfo_table_,
00810                                  page_images_, &sample_it, &unichar_error,
00811                                  NULL, report_string);
00812   return unichar_error;
00813 }
00814 
00815 // Returns the average (in some sense) distance between the two given
00816 // shapes, which may contain multiple fonts and/or unichars.
00817 float MasterTrainer::ShapeDistance(const ShapeTable& shapes, int s1, int s2) {
00818   const IntFeatureMap& feature_map = feature_map_;
00819   const Shape& shape1 = shapes.GetShape(s1);
00820   const Shape& shape2 = shapes.GetShape(s2);
00821   int num_chars1 = shape1.size();
00822   int num_chars2 = shape2.size();
00823   float dist_sum = 0.0f;
00824   int dist_count = 0;
00825   if (num_chars1 > 1 || num_chars2 > 1) {
00826     // In the multi-char case try to optimize the calculation by computing
00827     // distances between characters of matching font where possible.
00828     for (int c1 = 0; c1 < num_chars1; ++c1) {
00829       for (int c2 = 0; c2 < num_chars2; ++c2) {
00830         dist_sum += samples_.UnicharDistance(shape1[c1], shape2[c2],
00831                                              true, feature_map);
00832         ++dist_count;
00833       }
00834     }
00835   } else {
00836     // In the single unichar case, there is little alternative, but to compute
00837     // the squared-order distance between pairs of fonts.
00838     dist_sum = samples_.UnicharDistance(shape1[0], shape2[0],
00839                                         false, feature_map);
00840     ++dist_count;
00841   }
00842   return dist_sum / dist_count;
00843 }
00844 
00845 // Replaces samples that are always fragmented with the corresponding
00846 // fragment samples.
00847 void MasterTrainer::ReplaceFragmentedSamples() {
00848   if (fragments_ == NULL) return;
00849   // Remove samples that are replaced by fragments. Each class that was
00850   // always naturally fragmented should be replaced by its fragments.
00851   int num_samples = samples_.num_samples();
00852   for (int s = 0; s < num_samples; ++s) {
00853     TrainingSample* sample = samples_.mutable_sample(s);
00854     if (fragments_[sample->class_id()] > 0)
00855       samples_.KillSample(sample);
00856   }
00857   samples_.DeleteDeadSamples();
00858 
00859   // Get ids of fragments in junk_samples_ that replace the dead chars.
00860   const UNICHARSET& frag_set = junk_samples_.unicharset();
00861 #if 0
00862   // TODO(rays) The original idea was to replace only graphemes that were
00863   // always naturally fragmented, but that left a lot of the Indic graphemes
00864   // out. Determine whether we can go back to that idea now that spacing
00865   // is fixed in the training images, or whether this code is obsolete.
00866   bool* good_junk = new bool[frag_set.size()];
00867   memset(good_junk, 0, sizeof(*good_junk) * frag_set.size());
00868   for (int dead_ch = 1; dead_ch < unicharset_.size(); ++dead_ch) {
00869     int frag_ch = fragments_[dead_ch];
00870     if (frag_ch <= 0) continue;
00871     const char* frag_utf8 = frag_set.id_to_unichar(frag_ch);
00872     CHAR_FRAGMENT* frag = CHAR_FRAGMENT::parse_from_string(frag_utf8);
00873     // Mark the chars for all parts of the fragment as good in good_junk.
00874     for (int part = 0; part < frag->get_total(); ++part) {
00875       frag->set_pos(part);
00876       int good_ch = frag_set.unichar_to_id(frag->to_string().string());
00877       if (good_ch != INVALID_UNICHAR_ID)
00878         good_junk[good_ch] = true;  // We want this one.
00879     }
00880   }
00881 #endif
00882   // For now just use all the junk that was from natural fragments.
00883   // Get samples of fragments in junk_samples_ that replace the dead chars.
00884   int num_junks = junk_samples_.num_samples();
00885   for (int s = 0; s < num_junks; ++s) {
00886     TrainingSample* sample = junk_samples_.mutable_sample(s);
00887     int junk_id = sample->class_id();
00888     const char* frag_utf8 = frag_set.id_to_unichar(junk_id);
00889     CHAR_FRAGMENT* frag = CHAR_FRAGMENT::parse_from_string(frag_utf8);
00890     if (frag != NULL && frag->is_natural()) {
00891       junk_samples_.extract_sample(s);
00892       samples_.AddSample(frag_set.id_to_unichar(junk_id), sample);
00893     }
00894   }
00895   junk_samples_.DeleteDeadSamples();
00896   junk_samples_.OrganizeByFontAndClass();
00897   samples_.OrganizeByFontAndClass();
00898   unicharset_.clear();
00899   unicharset_.AppendOtherUnicharset(samples_.unicharset());
00900   // delete [] good_junk;
00901   // Fragments_ no longer needed?
00902   delete [] fragments_;
00903   fragments_ = NULL;
00904 }
00905 
00906 // Runs a hierarchical agglomerative clustering to merge shapes in the given
00907 // shape_table, while satisfying the given constraints:
00908 // * End with at least min_shapes left in shape_table,
00909 // * No shape shall have more than max_shape_unichars in it,
00910 // * Don't merge shapes where the distance between them exceeds max_dist.
00911 const float kInfiniteDist = 999.0f;
00912 void MasterTrainer::ClusterShapes(int min_shapes,  int max_shape_unichars,
00913                                   float max_dist, ShapeTable* shapes) {
00914   int num_shapes = shapes->NumShapes();
00915   int max_merges = num_shapes - min_shapes;
00916   GenericVector<ShapeDist>* shape_dists =
00917       new GenericVector<ShapeDist>[num_shapes];
00918   float min_dist = kInfiniteDist;
00919   int min_s1 = 0;
00920   int min_s2 = 0;
00921   tprintf("Computing shape distances...");
00922   for (int s1 = 0; s1 < num_shapes; ++s1) {
00923     for (int s2 = s1 + 1; s2 < num_shapes; ++s2) {
00924       ShapeDist dist(s1, s2, ShapeDistance(*shapes, s1, s2));
00925       shape_dists[s1].push_back(dist);
00926       if (dist.distance < min_dist) {
00927         min_dist = dist.distance;
00928         min_s1 = s1;
00929         min_s2 = s2;
00930       }
00931     }
00932     tprintf(" %d", s1);
00933   }
00934   tprintf("\n");
00935   int num_merged = 0;
00936   while (num_merged < max_merges && min_dist < max_dist) {
00937     tprintf("Distance = %f: ", min_dist);
00938     int num_unichars = shapes->MergedUnicharCount(min_s1, min_s2);
00939     shape_dists[min_s1][min_s2 - min_s1 - 1].distance = kInfiniteDist;
00940     if (num_unichars > max_shape_unichars) {
00941       tprintf("Merge of %d and %d with %d would exceed max of %d unichars\n",
00942               min_s1, min_s2, num_unichars, max_shape_unichars);
00943     } else {
00944       shapes->MergeShapes(min_s1, min_s2);
00945       shape_dists[min_s2].clear();
00946       ++num_merged;
00947 
00948       for (int s = 0; s < min_s1; ++s) {
00949         if (!shape_dists[s].empty()) {
00950           shape_dists[s][min_s1 - s - 1].distance =
00951               ShapeDistance(*shapes, s, min_s1);
00952           shape_dists[s][min_s2 - s -1].distance = kInfiniteDist;
00953         }
00954       }
00955       for (int s2 = min_s1 + 1; s2 < num_shapes; ++s2) {
00956         if (shape_dists[min_s1][s2 - min_s1 - 1].distance < kInfiniteDist)
00957           shape_dists[min_s1][s2 - min_s1 - 1].distance =
00958               ShapeDistance(*shapes, min_s1, s2);
00959       }
00960       for (int s = min_s1 + 1; s < min_s2; ++s) {
00961         if (!shape_dists[s].empty()) {
00962           shape_dists[s][min_s2 - s - 1].distance = kInfiniteDist;
00963         }
00964       }
00965     }
00966     min_dist = kInfiniteDist;
00967     for (int s1 = 0; s1 < num_shapes; ++s1) {
00968       for (int i = 0; i < shape_dists[s1].size(); ++i) {
00969         if (shape_dists[s1][i].distance < min_dist) {
00970           min_dist = shape_dists[s1][i].distance;
00971           min_s1 = s1;
00972           min_s2 = s1 + 1 + i;
00973         }
00974       }
00975     }
00976   }
00977   tprintf("Stopped with %d merged, min dist %f\n", num_merged, min_dist);
00978   delete [] shape_dists;
00979   if (debug_level_ > 1) {
00980     for (int s1 = 0; s1 < num_shapes; ++s1) {
00981       if (shapes->MasterDestinationIndex(s1) == s1) {
00982         tprintf("Master shape:%s\n", shapes->DebugStr(s1).string());
00983       }
00984     }
00985   }
00986 }
00987 
00988 
00989 }  // namespace tesseract.
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines