74 tesseract::CheckSharedLibraryVersion();
76 if (FLAGS_model_output.empty()) {
77 tprintf(
"Must provide a --model_output!\n");
80 if (FLAGS_traineddata.empty()) {
81 tprintf(
"Must provide a --traineddata see training wiki\n");
87 test_file +=
"_wtest";
88 FILE* f = fopen(test_file.
c_str(),
"wb");
91 if (
remove(test_file.
c_str()) != 0) {
92 tprintf(
"Error, failed to remove %s: %s\n",
93 test_file.
c_str(), strerror(errno));
97 tprintf(
"Error, model output cannot be written: %s\n", strerror(errno));
102 STRING checkpoint_file = FLAGS_model_output.
c_str();
103 checkpoint_file +=
"_checkpoint";
104 STRING checkpoint_bak = checkpoint_file +
".bak";
106 nullptr,
nullptr,
nullptr,
nullptr, FLAGS_model_output.c_str(),
107 checkpoint_file.
c_str(), FLAGS_debug_interval,
108 static_cast<int64_t
>(FLAGS_max_image_MB) * 1048576);
109 trainer.InitCharSet(FLAGS_traineddata.c_str());
113 if (FLAGS_stop_training || FLAGS_debug_network) {
114 if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
nullptr)) {
115 tprintf(
"Failed to read continue from: %s\n",
116 FLAGS_continue_from.c_str());
119 if (FLAGS_debug_network) {
120 trainer.DebugNetwork();
122 if (FLAGS_convert_to_int) trainer.ConvertToInt();
123 if (!trainer.SaveTraineddata(FLAGS_model_output.c_str())) {
124 tprintf(
"Failed to write recognition model : %s\n",
125 FLAGS_model_output.c_str());
132 if (FLAGS_train_listfile.empty()) {
133 tprintf(
"Must supply a list of training filenames! --train_listfile\n");
139 tprintf(
"Failed to load list of training filenames from %s\n",
140 FLAGS_train_listfile.c_str());
145 if (trainer.TryLoadingCheckpoint(checkpoint_file.
string(),
nullptr) ||
146 trainer.TryLoadingCheckpoint(checkpoint_bak.
string(),
nullptr)) {
147 tprintf(
"Successfully restored trainer from %s\n",
148 checkpoint_file.
string());
150 if (!FLAGS_continue_from.empty()) {
152 if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
153 FLAGS_append_index >= 0
154 ? FLAGS_continue_from.c_str()
155 : FLAGS_old_traineddata.c_str())) {
156 tprintf(
"Failed to continue from: %s\n", FLAGS_continue_from.c_str());
159 tprintf(
"Continuing from %s\n", FLAGS_continue_from.c_str());
160 trainer.InitIterations();
162 if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
163 if (FLAGS_append_index >= 0) {
164 tprintf(
"Appending a new network to an old one!!");
165 if (FLAGS_continue_from.empty()) {
166 tprintf(
"Must set --continue_from for appending!\n");
171 if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index,
172 FLAGS_net_mode, FLAGS_weight_range,
173 FLAGS_learning_rate, FLAGS_momentum,
175 tprintf(
"Failed to create network from spec: %s\n",
176 FLAGS_net_spec.c_str());
179 trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
182 if (!trainer.LoadAllTrainingData(filenames,
183 FLAGS_sequential_training
186 FLAGS_randomly_rotate)) {
187 tprintf(
"Load of images failed!!\n");
194 if (!FLAGS_eval_listfile.empty()) {
195 if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
196 tprintf(
"Failed to load eval data from: %s\n",
197 FLAGS_eval_listfile.c_str());
205 int iteration = trainer.training_iteration();
207 iteration < target_iteration &&
208 (iteration < FLAGS_max_iterations || FLAGS_max_iterations == 0);
209 iteration = trainer.training_iteration()) {
210 trainer.TrainOnLine(&trainer,
false);
213 trainer.MaintainCheckpoints(tester_callback, &log_str);
215 }
while (trainer.best_error_rate() > FLAGS_target_error_rate &&
216 (trainer.training_iteration() < FLAGS_max_iterations ||
217 FLAGS_max_iterations == 0));
218 delete tester_callback;
219 tprintf(
"Finished! Error rate = %g\n", trainer.best_error_rate());
bool LoadFileLinesToStrings(const STRING &filename, GenericVector< STRING > *lines)
const int kNumPagesPerBatch
STRING RunEvalAsync(int iteration, const double *training_errors, const TessdataManager &model_mgr, int training_stage)
const char * string() const
DLLSYM void tprintf(const char *format,...)
void ParseArguments(int *argc, char ***argv)
_ConstTessMemberResultCallback_5_0< false, R, T1, P1, P2, P3, P4, P5 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)(P1, P2, P3, P4, P5) const, typename Identity< P1 >::type p1, typename Identity< P2 >::type p2, typename Identity< P3 >::type p3, typename Identity< P4 >::type p4, typename Identity< P5 >::type p5)
const char * c_str() const