#!/usr/bin/python
# -*- encoding: utf-8; py-indent-offset: 4 -*-
# +------------------------------------------------------------------+
# |             ____ _               _        __  __ _  __           |
# |            / ___| |__   ___  ___| | __   |  \/  | |/ /           |
# |           | |   | '_ \ / _ \/ __| |/ /   | |\/| | ' /            |
# |           | |___| | | |  __/ (__|   <    | |  | | . \            |
# |            \____|_| |_|\___|\___|_|\_\___|_|  |_|_|\_\           |
# |                                                                  |
# | Copyright Mathias Kettner 2014             mk@mathias-kettner.de |
# +------------------------------------------------------------------+
#
# This file is part of Check_MK.
# The official homepage is at http://mathias-kettner.de/check_mk.
#
# check_mk is free software;  you can redistribute it and/or modify it
# under the  terms of the  GNU General Public License  as published by
# the Free Software Foundation in version 2.  check_mk is  distributed
# in the hope that it will be useful, but WITHOUT ANY WARRANTY;  with-
# out even the implied warranty of  MERCHANTABILITY  or  FITNESS FOR A
# PARTICULAR PURPOSE. See the  GNU General Public License for more de-
# tails. You should have  received  a copy of the  GNU  General Public
# License along with GNU Make; see the file  COPYING.  If  not,  write
# to the Free Software Foundation, Inc., 51 Franklin St,  Fifth Floor,
# Boston, MA 02110-1301 USA.

# This check performs HTTP requests with some advanced features like
# a) Detecting, populating and submitting HTML forms
# b) Accepts and uses cookies
# c) Follows HTTP redirects
# d) Extends HTTP headers

# Example calls:
#
# Call the page test.php, find the single form on that page and
# submit it with default values:
#
#  ./check_form_submit -I localhost -u /test.php
#
# Same as above, but expect the string "Hello" in the response
# of the form:
#
# ./check_form_submit -I localhost -u /test.php -e "Hello"
#
# Login as omdadmin with password omd, in the OMD site named /event,
# let the login redirect to the wato.py and expect the string WATO
# in this response:
#
# ./check_form_submit -I localhost -u /event -q '_origtarget=wato.py&_username=omdadmin&_password=omd' -e 'WATO'

import cookielib
import getopt
import re
import socket
import sys
import urllib
import urllib2
from HTMLParser import HTMLParser


def usage():
    """Print the command line help text of this check to stderr."""
    help_text = """
USAGE: check_form_submit -I <HOSTADDRESS> [-u <URI>] [-p <PORT>] [-s]
            [-f <FORMNAME>] [-q <QUERYPARAMS>] [-e <REGEX>] [-t <TIMEOUT> ]
            [-n <WARN>,<CRIT>]

OPTIONS:
  -I HOSTADDRESS                The IP address or hostname of the host to contact.
                                This option can be given multiple times to contact
                                several hosts with the same query.
  -u URI                        The URL string to query, "/" by default
  -p PORT                       TCP Port to communicate with
  -s                            Encrypt the connection using SSL
  -f FORMNAME                   Name of the form to fill, must match with the
                                contents of the "name" attribute
  -q QUERYPARAMS                Keys/Values of form fields to be popuplated
  -e REGEX                      Expected regular expression in the HTML response
                                after form submission
  -t TIMEOUT                    HTTP connect timeout in seconds
  -n WARN,CRIT                  Is only used in "multiple" mode. Number of successful
                                responses to result in a WARN and/or CRIT state
  -d                            Enable debug mode
  -h, --help                    Show this help message and exit

"""
    sys.stderr.write(help_text)


def debug(msg):
    """Write msg followed by a newline to stderr, but only in debug mode.

    Debug mode is controlled by the module level flag opt_debug.
    """
    if not opt_debug:
        return
    sys.stderr.write('%s\n' % msg)


class HostResult(Exception):
    """Exception carrying the check outcome computed for one host.

    The outcome is stored unchanged in the attribute "result"; it is not
    passed on to the Exception base class, so "args" stays empty.
    """
    def __init__(self, result):
        Exception.__init__(self)
        self.result = result


def new_state(rc, s):
    """Abort processing of the current host by raising a HostResult.

    The raised exception carries the tuple (rc, s) in its "result"
    attribute. This function never returns.
    """
    outcome = (rc, s)
    raise HostResult(outcome)


def bail_out(rc, s):
    """Print the final check result line to stdout and terminate.

    rc is used both as index into the state names (0=OK, 1=WARN, 2=CRIT,
    3=UNKN) and as exit code of the process. s is the output text.
    """
    state_names = ('OK', 'WARN', 'CRIT', 'UNKN')
    line = '%s - %s\n' % (state_names[rc], s)
    sys.stdout.write(line)
    sys.exit(rc)


def get_base_url(ssl, host, port):
    """Build the scheme://host[:port] prefix used for all requests to a host.

    HTTPS is used when ssl is set or when the port is 443. The port is only
    made explicit when it differs from the default port of the chosen
    scheme (80 for http, 443 for https).
    """
    use_https = bool(ssl) or port == 443

    if use_https:
        scheme, default_port = 'https', 443
    else:
        scheme, default_port = 'http', 80

    suffix = '' if port == default_port else ':%d' % port
    return '%s://%s%s' % (scheme, host, suffix)


# TODO: Refactor to requests
def init_http():
    """Create the urllib2 opener used for all HTTP requests of this check.

    The opener follows redirects, speaks HTTP and HTTPS and keeps cookies
    in a fresh in-memory cookie jar.
    """
    cookie_jar = cookielib.CookieJar()
    handlers = [
        urllib2.HTTPRedirectHandler(),
        urllib2.HTTPHandler(debuglevel=0),
        urllib2.HTTPSHandler(debuglevel=0),
        urllib2.HTTPCookieProcessor(cookie_jar),
    ]
    return urllib2.build_opener(*handlers)


def open_url(client, url, method='GET', data=None, timeout=None):
    """Fetch url with the given opener and return (code, real_url, content).

    client  -- opener object providing open(), e.g. the result of init_http()
    url     -- URL to request
    method  -- 'GET' encodes data into the query string of the URL, any
               other method sends data as urlencoded request body (POST)
    data    -- dictionary of form variables or None
    timeout -- timeout in seconds handed over to the opener

    real_url is the URL reported by the response (it may differ from url
    after redirects). The content is decoded when the response announces a
    charset. HTTP errors, URL errors and timeouts end the processing of the
    current host by raising HostResult (via new_state) with state CRIT.
    """
    if method == 'GET' and data is not None:
        # Add the query string to the url in this case
        start = '&' if '?' in url else '?'
        url += start + urllib.urlencode(data)
        data = None

    response = None
    try:
        # Test for None instead of truthiness: a form without any fields
        # must still be submitted as POST, not silently turned into a GET.
        if data is not None:
            debug('POST %s' % url)
            data = urllib.urlencode(data.items())
            debug('  => %s' % data)
            response = client.open(url, data, timeout)  # will be a POST
        else:
            debug('GET %s' % url)
            response = client.open(url, timeout=timeout)  # GET

        # Reading the body happens inside the try block as well, so that a
        # timeout during the transfer is reported like a connect timeout.
        real_url = response.geturl()
        code = response.getcode()
        content = response.read()
        encoding = response.headers.getparam('charset')
    except urllib2.HTTPError as e:
        new_state(2, 'Unable to open %s: [%d] %s' % (url, e.code, e))
    except urllib2.URLError as e:
        new_state(2, 'Unable to open %s: %s' % (url, e.reason))
    except socket.timeout as e:
        new_state(2, 'Unable to open %s: %s' % (url, e))
    finally:
        # Always release the connection, also in case of errors
        if response is not None:
            response.close()

    if encoding:
        content = content.decode(encoding)

    debug('CODE: %s RESPONSE:' % code)
    debug('\n'.join(
        ['    %02d %s' % (index + 1, line) for index, line in enumerate(content.split('\n'))]))
    return code, real_url, content


class FormParser(HTMLParser):
    """HTML parser collecting all forms and their input fields.

    After feeding a document, self.forms maps the form name to a dictionary
    with the keys 'attrs' (attributes of the form tag) and 'elements'
    (name -> prefilled value of the input tags within that form). Forms
    without a name are registered as 'unnamed-<N>'.
    """
    def __init__(self):
        self.forms = {}
        self.current_form = None
        HTMLParser.__init__(self)

    def handle_starttag(self, tag, attrs):
        """Register form tags and the named input tags inside of them."""
        attrs = dict(attrs)

        if tag == 'form':
            # Attributes written without a value (e.g. <form name>) are
            # reported as None by the parser: treat this like a missing name
            name = attrs.get('name')
            if name is None:
                name = 'unnamed-%d' % (len(self.forms) + 1)
            self.forms[name] = {
                'attrs': attrs,
                'elements': {},
            }
            self.current_form = self.forms[name]
        elif tag == 'input':
            if self.current_form is None:
                debug('Ignoring form field out of form tag')
                return

            name = attrs.get('name')
            if name is None:
                debug('Ignoring form field without name %r' % attrs)
                return

            # A valueless "value" attribute is reported as None. Submit an
            # empty string in this case instead of the text "None".
            value = attrs.get('value')
            self.current_form['elements'][name] = '' if value is None else value

    def handle_endtag(self, tag):
        """Leave the current form when its closing tag is reached."""
        if tag == 'form':
            self.current_form = None


# Parse the HTML code to find all form elements.
# If exactly one form is found and no form_name is given, use that one.
# If a form_name is given, use the form with the matching name,
# otherwise raise an exception.
def parse_form(content, form_name):
    """Return the form to submit out of the HTML code in content.

    content   -- HTML code of the page containing the form(s)
    form_name -- name of the wanted form or None

    Without form_name the page must contain exactly one form, which is
    returned. With form_name the form having that name is returned. In all
    other cases the processing of the current host is ended with state CRIT
    by new_state(), which raises HostResult and never returns.

    The returned dictionary has the keys 'attrs' and 'elements' as built
    by FormParser.
    """
    parser = FormParser()
    parser.feed(content)
    forms = parser.forms

    if not forms:
        new_state(2, 'Found no form element in HTML code')

    if form_name is None:
        if len(forms) > 1:
            new_state(
                2, 'No form name provided but found multiple forms (Names: %s)' %
                (', '.join(forms.keys())))
        # Exactly one form and no name requested: use that one
        return list(forms.values())[0]

    if form_name not in forms:
        # Complain only when the requested name does NOT match. The previous
        # check was inverted: it rejected a correctly named single form and
        # silently accepted a single form with a different name.
        if len(forms) == 1:
            new_state(
                2, 'Found one form with name "%s" but expected name "%s"' %
                (list(forms.keys())[0], form_name))
        new_state(
            2, 'Found no form with name "%s" (Available: %s)' %
            (form_name, ', '.join(forms.keys())))

    return forms[form_name]


def update_form_vars(form_elem, params):
    """Return the variables to submit for a form.

    Starts with a copy of the prefilled values of the form's elements and
    overrides / extends them with params. form_elem itself is not modified.
    """
    merged = dict(form_elem['elements'])
    merged.update(params)
    return merged


# Debug mode flag: set to True by the -d option in main(), read by debug()
opt_debug = False


def _parse_query_params(query):
    """Convert a 'key=value&key2=value2' string into a dictionary.

    Empty parts are skipped and a part without '=' results in an empty
    value instead of aborting with an exception.
    """
    params = {}
    for part in query.split('&'):
        if not part:
            continue
        key, _sep, value = part.partition('=')
        params[key] = value
    return params


def _resolve_form_target(base_url, real_url, action):
    """Compute the URL a form has to be submitted to.

    base_url -- scheme://host[:port] of the host being checked
    real_url -- URL of the page which contained the form
    action   -- value of the form's "action" attribute, may be None or ''
    """
    # Local import: urlparse is only needed here and is not part of the
    # module level imports of this file.
    import urlparse

    if not action:
        # Missing or empty action: the form is submitted to its own page
        return real_url

    if action.startswith('/') and not action.startswith('//'):
        # target is given as absolute path, relative to hostname
        return base_url + action

    # Relative paths, complete URLs and scheme relative URLs ("//host/...")
    return urlparse.urljoin(real_url, action)


def main():
    """Parse the command line, submit the form on each host, report the result.

    With a single host (-I given once) the result of that host is printed
    and used as exit code. With multiple hosts a summary over all hosts is
    printed, optionally evaluated against the -n thresholds. Unexpected
    exceptions lead to an UNKNOWN state (or a traceback in debug mode).
    """
    global opt_debug
    # NOTE: -H is accepted by getopt but not evaluated below
    short_options = 'I:u:p:H:f:q:e:t:n:sd'
    long_options = ["help"]

    try:
        opts = getopt.getopt(sys.argv[1:], short_options, long_options)[0]
    except getopt.GetoptError as err:
        sys.stderr.write("%s\n" % err)
        sys.exit(1)

    hosts = []
    multiple = False
    uri = '/'
    port = 80
    ssl = False
    form_name = None
    params = {}
    expect_regex = None
    timeout = 10  # seconds
    num_warn = None
    num_crit = None

    o = a = None
    try:
        for o, a in opts:
            if o in ['-h', '--help']:
                usage()
                sys.exit(0)
            elif o == '-I':
                hosts.append(a)
            elif o == '-u':
                uri = a
            elif o == '-p':
                port = int(a)
            elif o == '-s':
                ssl = True
            elif o == '-f':
                form_name = a
            elif o == '-q':
                params = _parse_query_params(a)
            elif o == '-e':
                expect_regex = a
            elif o == '-t':
                timeout = int(a)
            elif o == '-n':
                if ',' in a:
                    num_warn, num_crit = map(int, a.split(',', 1))
                else:
                    num_warn = int(a)
            elif o == '-d':
                opt_debug = True
    except ValueError:
        # Non-numeric value for -p, -t or -n: report it instead of a traceback
        bail_out(3, 'Invalid value "%s" for option %s' % (a, o))

    if not hosts:
        sys.stderr.write('Please provide the host to query via -I <HOSTADDRESS>.\n')
        usage()
        sys.exit(1)

    if len(hosts) > 1:
        multiple = True

    try:
        client = init_http()

        states = {}
        for host in hosts:
            base_url = get_base_url(ssl, host, port)
            try:
                # Perform first HTTP request to fetch the page containing the form(s)
                _code, real_url, body = open_url(client, base_url + uri, timeout=timeout)

                form = parse_form(body, form_name)

                # Get all fields and prefilled values from that form
                # Put the values of the given query params in these forms
                form_vars = update_form_vars(form, params)

                # Issue a HTTP request with those parameters
                # Extract the form target and method. Valueless attributes
                # are reported as None and treated like missing ones.
                method = (form['attrs'].get('method') or 'GET').upper()
                target = _resolve_form_target(base_url, real_url, form['attrs'].get('action'))

                _code, real_url, content = open_url(client,
                                                    target,
                                                    method,
                                                    form_vars,
                                                    timeout=timeout)

                # If an expect_regex is given, check whether or not it is present in the response
                if expect_regex is not None:
                    matches = re.search(expect_regex, content)
                    if matches is not None:
                        new_state(0, 'Found expected regex "%s" in form response' % expect_regex)
                    else:
                        new_state(
                            2, 'Expected regex "%s" could not be found in form response' %
                            expect_regex)
                else:
                    new_state(0, 'Form has been submitted')
            except HostResult as e:
                # Every path through the block above ends in a HostResult
                if not multiple:
                    bail_out(*e.result)
                states[host] = e.result

        if multiple:
            failed = [pair for pair in states.items() if pair[1][0] != 0]
            success = [pair for pair in states.items() if pair[1][0] == 0]
            max_state = max([state for state, _output in states.values()])

            sum_state = 0
            if num_warn is None and num_crit is None:
                # use the worst state as summary state
                sum_state = max_state

            elif num_crit is not None and len(success) <= num_crit:
                sum_state = 2

            elif num_warn is not None and len(success) <= num_warn:
                sum_state = 1

            txt = '%d succeeded, %d failed' % (len(success), len(failed))
            if failed:
                txt += ' (%s)' % ', '.join(['%s: %s' % (n, msg[1]) for n, msg in failed])
            bail_out(sum_state, txt)

    except Exception as e:
        if opt_debug:
            raise
        bail_out(3, 'Exception occured: %s\n' % e)


# Run the check only when executed as a script, not when imported as a module
if __name__ == "__main__":
    main()
