// Copyright (C) 2011  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H__
#define DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H__

#include "structural_object_detection_trainer_abstract.h"
#include "../algs.h"
#include "../optimization.h"
#include "structural_svm_object_detection_problem.h"
#include "../image_processing/object_detector.h"
#include "../image_processing/box_overlap_testing.h"


namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename image_scanner_type,
        typename overlap_tester_type = test_box_overlap
        >
    class structural_object_detection_trainer : noncopyable
    {

    public:
        typedef double scalar_type;
        typedef default_memory_manager mem_manager_type;
        typedef object_detector<image_scanner_type,overlap_tester_type> trained_function_type;


        explicit structural_object_detection_trainer (
            const image_scanner_type& scanner_
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(scanner_.get_num_detection_templates() > 0,
                "\t structural_object_detection_trainer::structural_object_detection_trainer(scanner_)"
                << "\n\t You can't have zero detection templates"
                << "\n\t this: " << this
                );

            C = 1;
            verbose = false;
            eps = 0.3;
            num_threads = 2;
            max_cache_size = 40;
            match_eps = 0.5;
            loss_per_missed_target = 1;
            loss_per_false_alarm = 1;

            scanner.copy_configuration(scanner_);

            auto_overlap_tester = is_same_type<overlap_tester_type,test_box_overlap>::value;
        }

        bool auto_set_overlap_tester (
        ) const 
        { 
            return auto_overlap_tester; 
        }

        void set_overlap_tester (
            const overlap_tester_type& tester
        )
        {
            overlap_tester = tester;
            auto_overlap_tester = false;
        }

        overlap_tester_type get_overlap_tester (
        ) const
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(auto_set_overlap_tester() == false,
                "\t overlap_tester_type structural_object_detection_trainer::get_overlap_tester()"
                << "\n\t You can't call this function if the overlap tester is generated dynamically."
                << "\n\t this: " << this
                );

            return overlap_tester;
        }

        void set_num_threads (
            unsigned long num
        )
        {
            num_threads = num;
        }

        unsigned long get_num_threads (
        ) const
        {
            return num_threads;
        }

        void set_epsilon (
            scalar_type eps_
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(eps_ > 0,
                "\t void structural_object_detection_trainer::set_epsilon()"
                << "\n\t eps_ must be greater than 0"
                << "\n\t eps_: " << eps_ 
                << "\n\t this: " << this
                );

            eps = eps_;
        }

        scalar_type get_epsilon (
        ) const { return eps; }

        void set_max_cache_size (
            unsigned long max_size
        )
        {
            max_cache_size = max_size;
        }

        unsigned long get_max_cache_size (
        ) const
        {
            return max_cache_size; 
        }

        void be_verbose (
        )
        {
            verbose = true;
        }

        void be_quiet (
        )
        {
            verbose = false;
        }

        void set_oca (
            const oca& item
        )
        {
            solver = item;
        }

        const oca get_oca (
        ) const
        {
            return solver;
        }

        void set_c (
            scalar_type C_ 
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(C_ > 0,
                "\t void structural_object_detection_trainer::set_c()"
                << "\n\t C_ must be greater than 0"
                << "\n\t C_:    " << C_ 
                << "\n\t this: " << this
                );

            C = C_;
        }

        scalar_type get_c (
        ) const
        {
            return C;
        }

        void set_match_eps (
            double eps
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(0 < eps && eps < 1, 
                "\t void structural_object_detection_trainer::set_match_eps(eps)"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t eps:  " << eps 
                << "\n\t this: " << this
                );

            match_eps = eps;
        }

        double get_match_eps (
        ) const
        {
            return match_eps;
        }

        double get_loss_per_missed_target (
        ) const
        {
            return loss_per_missed_target;
        }

        void set_loss_per_missed_target (
            double loss
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(loss > 0, 
                "\t void structural_object_detection_trainer::set_loss_per_missed_target(loss)"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t loss: " << loss
                << "\n\t this: " << this
                );

            loss_per_missed_target = loss;
        }

        double get_loss_per_false_alarm (
        ) const
        {
            return loss_per_false_alarm;
        }

        void set_loss_per_false_alarm (
            double loss
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(loss > 0, 
                "\t void structural_object_detection_trainer::set_loss_per_false_alarm(loss)"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t loss: " << loss
                << "\n\t this: " << this
                );

            loss_per_false_alarm = loss;
        }

        template <
            typename image_array_type
            >
        const trained_function_type train (
            const image_array_type& images,
            const std::vector<std::vector<rectangle> >& truth_rects
        ) const
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(is_learning_problem(images,truth_rects) == true,
                "\t trained_function_type structural_object_detection_trainer::train(x,y)"
                << "\n\t invalid inputs were given to this function"
                << "\n\t images.size():      " << images.size()
                << "\n\t truth_rects.size(): " << truth_rects.size()
                << "\n\t is_learning_problem(images,truth_rects): " << is_learning_problem(images,truth_rects)
                );

            overlap_tester_type local_overlap_tester;

            if (auto_overlap_tester)
            {
                std::vector<std::vector<rectangle> > mapped_rects(truth_rects.size());
                for (unsigned long i = 0; i < truth_rects.size(); ++i)
                {
                    mapped_rects[i].resize(truth_rects[i].size());
                    for (unsigned long j = 0; j < truth_rects[i].size(); ++j)
                    {
                        mapped_rects[i][j] = scanner.get_best_matching_rect(truth_rects[i][j]);
                    }
                }

                local_overlap_tester = find_tight_overlap_tester(mapped_rects);
            }
            else
            {
                local_overlap_tester = overlap_tester;
            }

            structural_svm_object_detection_problem<image_scanner_type,overlap_tester_type,image_array_type > 
                svm_prob(scanner, local_overlap_tester, images, truth_rects, num_threads);

            if (verbose)
                svm_prob.be_verbose();

            svm_prob.set_c(C);
            svm_prob.set_epsilon(eps);
            svm_prob.set_max_cache_size(max_cache_size);
            svm_prob.set_match_eps(match_eps);
            svm_prob.set_loss_per_missed_target(loss_per_missed_target);
            svm_prob.set_loss_per_false_alarm(loss_per_false_alarm);
            matrix<double,0,1> w;

            // Run the optimizer to find the optimal w.
            solver(svm_prob,w);

            // report the results of the training.
            return object_detector<image_scanner_type,overlap_tester_type>(scanner, local_overlap_tester, w);
        }


    private:

        image_scanner_type scanner;
        overlap_tester_type overlap_tester;

        double C;
        oca solver;
        double eps;
        double match_eps;
        bool verbose;
        unsigned long num_threads;
        unsigned long max_cache_size;
        double loss_per_missed_target;
        double loss_per_false_alarm;
        bool auto_overlap_tester;

    }; 

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H__