|
tesseract 3.04.01
|
00001 /********************************************************************** 00002 * File: word_unigrams.cpp 00003 * Description: Implementation of the Word Unigrams Class 00004 * Author: Ahmad Abdulkader 00005 * Created: 2008 00006 * 00007 * (C) Copyright 2008, Google Inc. 00008 ** Licensed under the Apache License, Version 2.0 (the "License"); 00009 ** you may not use this file except in compliance with the License. 00010 ** You may obtain a copy of the License at 00011 ** http://www.apache.org/licenses/LICENSE-2.0 00012 ** Unless required by applicable law or agreed to in writing, software 00013 ** distributed under the License is distributed on an "AS IS" BASIS, 00014 ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00015 ** See the License for the specific language governing permissions and 00016 ** limitations under the License. 00017 * 00018 **********************************************************************/ 00019 00020 #include <math.h> 00021 #include <string> 00022 #include <vector> 00023 #include <algorithm> 00024 00025 #include "const.h" 00026 #include "cube_utils.h" 00027 #include "ndminx.h" 00028 #include "word_unigrams.h" 00029 00030 namespace tesseract { 00031 00032 WordUnigrams::WordUnigrams() { 00033 costs_ = NULL; 00034 words_ = NULL; 00035 word_cnt_ = 0; 00036 } 00037 00038 WordUnigrams::~WordUnigrams() { 00039 if (words_ != NULL) { 00040 if (words_[0] != NULL) { 00041 delete []words_[0]; 00042 } 00043 00044 delete []words_; 00045 words_ = NULL; 00046 } 00047 00048 if (costs_ != NULL) { 00049 delete []costs_; 00050 } 00051 } 00052 00057 WordUnigrams *WordUnigrams::Create(const string &data_file_path, 00058 const string &lang) { 00059 string file_name; 00060 string str; 00061 00062 file_name = data_file_path + lang; 00063 file_name += ".cube.word-freq"; 00064 00065 // load the string into memory 00066 if (CubeUtils::ReadFileToString(file_name, &str) == false) { 00067 return NULL; 00068 } 00069 00070 // split into lines 00071 vector<string> str_vec; 00072 CubeUtils::SplitStringUsing(str, "\r\n \t", &str_vec); 00073 if (str_vec.size() < 2) { 00074 return NULL; 00075 } 00076 00077 // allocate memory 00078 WordUnigrams *word_unigrams_obj = new WordUnigrams(); 00079 if (word_unigrams_obj == NULL) { 00080 fprintf(stderr, "Cube ERROR (WordUnigrams::Create): could not create " 00081 "word unigrams object.\n"); 00082 return NULL; 00083 } 00084 00085 int full_len = str.length(); 00086 int word_cnt = str_vec.size() / 2; 00087 word_unigrams_obj->words_ = new char*[word_cnt]; 00088 word_unigrams_obj->costs_ = new int[word_cnt]; 00089 00090 if (word_unigrams_obj->words_ == NULL || 00091 word_unigrams_obj->costs_ == NULL) { 00092 fprintf(stderr, "Cube ERROR (WordUnigrams::Create): error allocating " 00093 "word unigram fields.\n"); 00094 delete word_unigrams_obj; 00095 return NULL; 00096 } 00097 00098 word_unigrams_obj->words_[0] = new char[full_len]; 00099 if (word_unigrams_obj->words_[0] == NULL) { 00100 fprintf(stderr, "Cube ERROR (WordUnigrams::Create): error allocating " 00101 "word unigram fields.\n"); 00102 delete word_unigrams_obj; 00103 return NULL; 00104 } 00105 00106 // construct sorted list of words and costs 00107 word_unigrams_obj->word_cnt_ = 0; 00108 char *char_buff = word_unigrams_obj->words_[0]; 00109 word_cnt = 0; 00110 int max_cost = 0; 00111 00112 for (int wrd = 0; wrd < str_vec.size(); wrd += 2) { 00113 word_unigrams_obj->words_[word_cnt] = char_buff; 00114 00115 strcpy(char_buff, str_vec[wrd].c_str()); 00116 char_buff += (str_vec[wrd].length() + 1); 00117 00118 if (sscanf(str_vec[wrd + 1].c_str(), "%d", 00119 word_unigrams_obj->costs_ + word_cnt) != 1) { 00120 fprintf(stderr, "Cube ERROR (WordUnigrams::Create): error reading " 00121 "word unigram data.\n"); 00122 delete word_unigrams_obj; 00123 return NULL; 00124 } 00125 // update max cost 00126 max_cost = MAX(max_cost, word_unigrams_obj->costs_[word_cnt]); 00127 word_cnt++; 00128 } 00129 word_unigrams_obj->word_cnt_ = word_cnt; 00130 00131 // compute the not-in-list-cost by assuming that a word not in the list 00132 // [ahmadab]: This can be computed as follows: 00133 // - Given that the distribution of words follow Zipf's law: 00134 // (F = K / (rank ^ S)), where s is slightly > 1.0 00135 // - Number of words in the list is N 00136 // - The mean frequency of a word that did not appear in the list is the 00137 // area under the rest of the Zipf's curve divided by 2 (the mean) 00138 // - The area would be the bound integral from N to infinity = 00139 // (K * S) / (N ^ (S + 1)) ~= K / (N ^ 2) 00140 // - Given that cost = -LOG(prob), the cost of an unlisted word would be 00141 // = max_cost + 2*LOG(N) 00142 word_unigrams_obj->not_in_list_cost_ = max_cost + 00143 (2 * CubeUtils::Prob2Cost(1.0 / word_cnt)); 00144 // success 00145 return word_unigrams_obj; 00146 } 00147 00154 int WordUnigrams::Cost(const char_32 *key_str32, 00155 LangModel *lang_mod, 00156 CharSet *char_set) const { 00157 if (!key_str32) 00158 return 0; 00159 // convert string to UTF8 to split into space-separated words 00160 string key_str; 00161 CubeUtils::UTF32ToUTF8(key_str32, &key_str); 00162 vector<string> words; 00163 CubeUtils::SplitStringUsing(key_str, " \t", &words); 00164 00165 // no words => no cost 00166 if (words.size() <= 0) { 00167 return 0; 00168 } 00169 00170 // aggregate the costs of all the words 00171 int cost = 0; 00172 for (int word_idx = 0; word_idx < words.size(); word_idx++) { 00173 // convert each word back to UTF32 for analyzing case and punctuation 00174 string_32 str32; 00175 CubeUtils::UTF8ToUTF32(words[word_idx].c_str(), &str32); 00176 int len = CubeUtils::StrLen(str32.c_str()); 00177 00178 // strip all trailing punctuation 00179 string clean_str; 00180 int clean_len = len; 00181 bool trunc = false; 00182 while (clean_len > 0 && 00183 lang_mod->IsTrailingPunc(str32.c_str()[clean_len - 1])) { 00184 --clean_len; 00185 trunc = true; 00186 } 00187 00188 // If either the original string was not truncated (no trailing 00189 // punctuation) or the entire string was removed (all characters 00190 // are trailing punctuation), evaluate original word as is; 00191 // otherwise, copy all but the trailing punctuation characters 00192 char_32 *clean_str32 = NULL; 00193 if (clean_len == 0 || !trunc) { 00194 clean_str32 = CubeUtils::StrDup(str32.c_str()); 00195 } else { 00196 clean_str32 = new char_32[clean_len + 1]; 00197 for (int i = 0; i < clean_len; ++i) { 00198 clean_str32[i] = str32[i]; 00199 } 00200 clean_str32[clean_len] = '\0'; 00201 } 00202 ASSERT_HOST(clean_str32 != NULL); 00203 00204 string str8; 00205 CubeUtils::UTF32ToUTF8(clean_str32, &str8); 00206 int word_cost = CostInternal(str8.c_str()); 00207 00208 // if case invariant, get costs of all-upper-case and all-lower-case 00209 // versions and return the min cost 00210 if (clean_len >= kMinLengthNumOrCaseInvariant && 00211 CubeUtils::IsCaseInvariant(clean_str32, char_set)) { 00212 char_32 *lower_32 = CubeUtils::ToLower(clean_str32, char_set); 00213 if (lower_32) { 00214 string lower_8; 00215 CubeUtils::UTF32ToUTF8(lower_32, &lower_8); 00216 word_cost = MIN(word_cost, CostInternal(lower_8.c_str())); 00217 delete [] lower_32; 00218 } 00219 char_32 *upper_32 = CubeUtils::ToUpper(clean_str32, char_set); 00220 if (upper_32) { 00221 string upper_8; 00222 CubeUtils::UTF32ToUTF8(upper_32, &upper_8); 00223 word_cost = MIN(word_cost, CostInternal(upper_8.c_str())); 00224 delete [] upper_32; 00225 } 00226 } 00227 00228 if (clean_len >= kMinLengthNumOrCaseInvariant) { 00229 // if characters are all numeric, incur 0 word cost 00230 bool is_numeric = true; 00231 for (int i = 0; i < clean_len; ++i) { 00232 if (!lang_mod->IsDigit(clean_str32[i])) 00233 is_numeric = false; 00234 } 00235 if (is_numeric) 00236 word_cost = 0; 00237 } 00238 delete [] clean_str32; 00239 cost += word_cost; 00240 } // word_idx 00241 00242 // return the mean cost 00243 return static_cast<int>(cost / static_cast<double>(words.size())); 00244 } 00245 00249 int WordUnigrams::CostInternal(const char *key_str) const { 00250 if (strlen(key_str) == 0) 00251 return not_in_list_cost_; 00252 int hi = word_cnt_ - 1; 00253 int lo = 0; 00254 while (lo <= hi) { 00255 int current = (hi + lo) / 2; 00256 int comp = strcmp(key_str, words_[current]); 00257 // a match 00258 if (comp == 0) { 00259 return costs_[current]; 00260 } 00261 if (comp < 0) { 00262 // go lower 00263 hi = current - 1; 00264 } else { 00265 // go higher 00266 lo = current + 1; 00267 } 00268 } 00269 return not_in_list_cost_; 00270 } 00271 } // namespace tesseract