#!/usr/bin/python
# -*- encoding: utf-8; py-indent-offset: 4 -*-
# +------------------------------------------------------------------+
# |             ____ _               _        __  __ _  __           |
# |            / ___| |__   ___  ___| | __   |  \/  | |/ /           |
# |           | |   | '_ \ / _ \/ __| |/ /   | |\/| | ' /            |
# |           | |___| | | |  __/ (__|   <    | |  | | . \            |
# |            \____|_| |_|\___|\___|_|\_\___|_|  |_|_|\_\           |
# |                                                                  |
# | Copyright Mathias Kettner 2014             mk@mathias-kettner.de |
# +------------------------------------------------------------------+
#
# This file is part of Check_MK.
# The official homepage is at http://mathias-kettner.de/check_mk.
#
# check_mk is free software;  you can redistribute it and/or modify it
# under the  terms of the  GNU General Public License  as published by
# the Free Software Foundation in version 2.  check_mk is  distributed
# in the hope that it will be useful, but WITHOUT ANY WARRANTY;  with-
# out even the implied warranty of  MERCHANTABILITY  or  FITNESS FOR A
# PARTICULAR PURPOSE. See the  GNU General Public License for more de-
# tails. You should have  received  a copy of the  GNU  General Public
# License along with GNU Make; see the file  COPYING.  If  not,  write
# to the Free Software Foundation, Inc., 51 Franklin St,  Fifth Floor,
# Boston, MA 02110-1301 USA.

# DB2 support requires installation of the IBM Data Server Client:
#  http://www-01.ibm.com/support/docview.wss?uid=swg27016878
# as well as the ibm_db2 Python DBI driver for DB2:
#  https://pypi.python.org/pypi/ibm_db
"""Check_MK SQL Test"""
from typing import Tuple  # pylint: disable=unused-import
import argparse
import logging
import os
import sys

import cmk.utils.password_store
# Runs at import time, i.e. before parse_args() reads the command line.
# NOTE(review): presumably this substitutes password-store references in
# sys.argv with the stored secrets -- confirm in cmk.utils.password_store.
cmk.utils.password_store.replace_passwords()

# Module logger; level and format are set up in parse_args() through
# logging.basicConfig(), depending on the number of -v flags.
LOG = logging.getLogger(__name__)

# Port used by connect() when no --port is given, keyed by the --dbms choice.
DEFAULT_PORTS = {
    "postgres": 5432,
    "mysql": 3306,
    "mssql": 1433,
    "oracle": 1521,
    "db2": 50000,
}

# "No levels" range (lower, upper): every finite number satisfies
# lower <= number < upper. Default for --warning/--critical, and the marker
# process_result() compares against to detect that levels were given.
MP_INF = (float('-inf'), float('+inf'))  # type: Tuple[float, float]

#   . parse commandline arguments


def levels(values):  # type: (str) -> Tuple[float, float]
    """Parse a "LOWER:UPPER" range given to --warning / --critical.

    Either side may be left empty to mean "unbounded" on that side,
    e.g. "10:", ":100" or ":" (no bounds at all). process_result() treats a
    number as inside the range if LOWER <= number < UPPER.

    Returns a (lower, upper) tuple of floats.

    Raises argparse.ArgumentTypeError -- which argparse reports as a usage
    error including this message -- if the value does not consist of exactly
    two colon-separated parts or a non-empty part is not a number.
    """
    parts = values.split(':')
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(
            "invalid range %r: expected LOWER:UPPER (either side may be empty)" % values)
    lower, upper = parts
    try:
        _lower = float(lower) if lower else float('-inf')
        _upper = float(upper) if upper else float('+inf')
    except ValueError:
        raise argparse.ArgumentTypeError(
            "invalid range %r: LOWER and UPPER must be numbers" % values)
    return (_lower, _upper)


def sql_cmd_piece(values):
    """Unescape one command-line piece of the SQL command.

    The two-character sequences backslash-n and backslash-semicolon are
    turned into a real newline and a plain semicolon (in that order).
    """
    unescaped = values
    for escaped, plain in ((r"\n", "\n"), (r"\;", ";")):
        unescaped = unescaped.replace(escaped, plain)
    return unescaped


def parse_args(argv):
    """Parse the command line and set up logging.

    argv is the complete argument vector including the program name
    (argv[0]), e.g. sys.argv. The SQL command pieces are joined into a
    single string in ``args.cmd``. Logging verbosity follows the number of
    -v flags; with -vv or more every parsed argument is logged at DEBUG
    level, with user and password masked.

    Returns the argparse namespace. On invalid arguments argparse prints a
    usage message and exits.
    """
    parser = _build_parser(str(os.path.basename(argv[0])))
    args = parser.parse_args(argv[1:])
    # The statement may arrive as several shell words; glue them back together.
    args.cmd = ' '.join(args.cmd)

    _setup_logging(args.verbose, args.dbms)

    # V-VERBOSE INFO (never log credentials)
    for key, val in vars(args).items():
        if key in ('user', 'password'):
            val = '****'
        LOG.debug('argparse: %s = %r', key, val)
    return args


def _build_parser(prog):
    """Return the ArgumentParser for this check (building it has no side effects)."""
    parser = argparse.ArgumentParser(
        prog=prog,
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # flags
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        default=0,
        help="Verbose mode: print SQL statement and levels "
        "(for even more output use -vv)",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Debug mode: let Python exceptions come through",
    )
    parser.add_argument(
        "-m",
        "--metrics",
        action="store_true",
        help="Add performance data to the output",
    )
    parser.add_argument(
        "-o",
        "--procedure",
        action="store_true",
        help="treat the main argument as a procedure instead "
        "of an SQL-Statement",
    )
    parser.add_argument(
        "-i",
        "--input",
        metavar="CSV",
        default=[],
        type=lambda s: s.split(','),
        help="comma separated list of values of input variables, "
        "if required by the procedure",
    )
    # optional arguments
    parser.add_argument(
        "-d",
        "--dbms",
        default='postgres',
        choices=['postgres', 'mysql', 'mssql', 'oracle', 'db2'],
        help='Name of the database management system. '
        'Default is "postgres"',
    )
    parser.add_argument(
        "-H",
        "--hostname",
        metavar='HOST',
        default='127.0.0.1',
        help='Hostname or IP-Address where the database lives. '
        'Default is "127.0.0.1"',
    )
    parser.add_argument(
        "-P",
        "--port",
        default=None,
        type=int,
        help="Port used to connect to the database. "
        "Default depends on DBMS",
    )
    parser.add_argument(
        "-w",
        "--warning",
        metavar='RANGE',
        default=MP_INF,
        type=levels,
        help="lower and upper level for the warning state, "
        "separated by a colon",
    )
    parser.add_argument(
        "-c",
        "--critical",
        metavar='RANGE',
        default=MP_INF,
        type=levels,
        help="lower and upper level for the critical state, "
        "separated by a colon",
    )
    parser.add_argument(
        "-t",
        "--text",
        default='',
        help="Additional text prefixed to the output",
    )

    # required arguments
    parser.add_argument(
        "-n",
        "--name",
        required=True,
        help="Name of the database on the DBMS",
    )
    parser.add_argument(
        "-u",
        "--user",
        required=True,
        help="Username for database access",
    )
    parser.add_argument(
        "-p",
        "--password",
        required=True,
        help="Password for database access",
    )
    parser.add_argument(
        "cmd",
        metavar="SQL-Statement|Procedure",
        type=sql_cmd_piece,
        nargs='+',
        help='Valid SQL-Statement for the selected database. '
        'The statement must return at least a number, optionally '
        'followed by a text and performance data. '
        'Alternatively: If the "-o" option is given, '
        'treat the argument as a procedure name. '
        'The procedure must return one output variable, '
        'whose content is evaluated the same way as the '
        'output of the SQL-Statement',
    )
    return parser


def _setup_logging(verbosity, dbms):
    """Configure the root logger: WARNING by default, INFO with -v, DEBUG with -vv."""
    log_format = "%(message)s"
    if verbosity > 1:
        log_format = "%(levelname)s: %(lineno)s: " + log_format
        if dbms == "mssql":
            # NOTE(review): TDSDUMP is presumably read by the FreeTDS library
            # underneath pymssql, making it dump protocol traffic to stdout.
            os.environ["TDSDUMP"] = "stdout"
    # 30 = WARNING, 20 = INFO, 10 = DEBUG; clamp at 0 (NOTSET) for -vvv and up.
    logging.basicConfig(level=max(30 - 10 * verbosity, 0), format=log_format)


#.


def bail_out(exit_code, output):
    """Write "<STATE> - <output>" to stdout and exit the process.

    exit_code selects the state name: 0 OK, 1 WARN, 2 CRIT, 3 UNKNOWN, and
    is also used as the process exit status.
    """
    state_names = ("OK", "WARN", "CRIT", "UNKNOWN")
    line = "%s - %s\n" % (state_names[exit_code], output)
    sys.stdout.write(line)
    sys.exit(exit_code)


#   . DBMS specific code here!
#
# For every DBMS, define a <dbms>_connect and a <dbms>_execute function and
# register them in the dispatch dicts of the top-level connect() and
# execute() functions.
#
def _default_execute(cursor, cmd, inpt, procedure):
    """Run cmd through the standard DB-API cursor and return all fetched rows.

    If procedure is true, cmd is a stored procedure name that is called via
    cursor.callproc() with the input values inpt; otherwise cmd is an SQL
    statement passed to cursor.execute().
    """
    if not procedure:
        LOG.info("SQL Statement: %s", cmd)
        cursor.execute(cmd)
        return cursor.fetchall()

    LOG.info("SQL Procedure Name: %s", cmd)
    LOG.info("Input Values: %s", inpt)
    cursor.callproc(cmd, inpt)
    LOG.debug("inpt after 'callproc' = %r", inpt)
    return cursor.fetchall()


def postgres_connect(host, port, db_name, user, pwd):
    """Open a PostgreSQL connection with psycopg2 (imported only when needed)."""
    import psycopg2
    connect_kwargs = dict(
        host=host,
        port=port,
        database=db_name,
        user=user,
        password=pwd,
    )
    return psycopg2.connect(**connect_kwargs)


def postgres_execute(cursor, cmd, inpt, procedure):
    """No PostgreSQL-specific handling; delegate to the generic DB-API path."""
    rows = _default_execute(cursor, cmd, inpt, procedure)
    return rows


def mysql_connect(host, port, db_name, user, pwd):
    """Open a MySQL connection with MySQLdb (imported only when needed)."""
    import MySQLdb
    # MySQLdb names the database and password keywords "db" and "passwd".
    return MySQLdb.connect(
        host=host,
        port=port,
        user=user,
        passwd=pwd,
        db=db_name,
    )


def mysql_execute(cursor, cmd, inpt, procedure):
    """No MySQL-specific handling; delegate to the generic DB-API path."""
    rows = _default_execute(cursor, cmd, inpt, procedure)
    return rows


def mssql_connect(host, port, db_name, user, pwd):
    """Open a Microsoft SQL Server connection with pymssql (imported only when needed)."""
    import pymssql
    connect_kwargs = dict(
        host=host,
        port=port,
        database=db_name,
        user=user,
        password=pwd,
    )
    return pymssql.connect(**connect_kwargs)


def mssql_execute(cursor, cmd, _inpt, procedure):
    """Execute cmd on Microsoft SQL Server and return all fetched rows.

    A procedure (procedure=True) is run as an "EXEC <name>" statement
    instead of cursor.callproc(). Input values given with -i (_inpt) are
    appended as bound parameters ("EXEC <name> %s, %s, ...") so that the
    driver quotes them; without input values the statement is executed
    without parameters, as before.
    """
    if procedure:
        LOG.info("SQL Procedure Name: %s", cmd)
        LOG.info("Input Values: %s", _inpt)
        params = tuple(_inpt or ())
        cmd = 'EXEC ' + cmd
        if params:
            # pymssql uses "%s" placeholders; never paste values into the SQL.
            cmd += ' ' + ', '.join(['%s'] * len(params))
            cursor.execute(cmd, params)
        else:
            cursor.execute(cmd)
    else:
        LOG.info("SQL Statement: %s", cmd)
        # No parameters: a literal "%" in a plain statement stays untouched.
        cursor.execute(cmd)

    return cursor.fetchall()


def oracle_connect(host, port, db_name, user, pwd):
    """Open an Oracle connection with cx_Oracle.

    The data source is the string "host:port/db_name" (Oracle Easy Connect
    syntax). User and password are passed as separate arguments instead of
    being embedded in a "user/password@dsn" string, so passwords containing
    "/" or "@" no longer break the connection.

    Exits with UNKNOWN (via bail_out) if cx_Oracle cannot be imported.
    """
    # NOTE(review): presumably added so a system-wide cx_Oracle installed for
    # this interpreter version is found -- confirm.
    sys.path.append('/usr/lib/python%s.%s/site-packages' %
                    (sys.version_info.major, sys.version_info.minor))
    try:
        import cx_Oracle  # pylint: disable=import-error
    except ImportError as exc:
        bail_out(3, "%s. Please install it via 'pip install cx_Oracle'." % exc)

    dsn = "%s:%s/%s" % (host, port, db_name)
    return cx_Oracle.connect(user, pwd, dsn)


def oracle_execute(cursor, cmd, inpt, procedure):
    """Execute cmd on Oracle and return the result rows.

    For an SQL statement, the rows of cursor.fetchall() are returned.

    For a procedure (procedure=True), a string output variable is passed
    after the input values inpt to cursor.callproc(). As the --procedure
    help text requires, the content of that output variable is evaluated
    like the output of an SQL statement: it is split at "," and returned as
    the single result row. A NULL output yields no rows.

    Exits with UNKNOWN (via bail_out) if cx_Oracle cannot be imported.
    """
    try:
        import cx_Oracle  # pylint: disable=import-error
    except ImportError as exc:
        bail_out(3, "%s. Please install it via 'pip install cx_Oracle'." % exc)

    if not procedure:
        LOG.info("SQL Statement: %s", cmd)
        cursor.execute(cmd)
        return cursor.fetchall()

    LOG.info("SQL Procedure Name: %s", cmd)
    LOG.info("Input Values: %s", inpt)

    # NOTE(review): this branch has not been tested against a real Oracle DB.
    outvar = cursor.var(cx_Oracle.STRING)  # pylint:disable=undefined-variable
    # Build a new list so the caller's input list is not mutated.
    params = list(inpt) + [outvar]
    cursor.callproc(cmd, params)
    LOG.debug("outvar = %r", outvar)

    # Only the output variable carries the result. The input values are
    # plain strings (no getvalue()) and must not be mixed into it.
    raw = outvar.getvalue()
    LOG.debug("outvar.getvalue() = %r", raw)
    if raw is None:
        return []
    params_result = raw.split(",")
    LOG.debug("params_result = %r", params_result)
    # Return the output variable as the one result row instead of calling
    # fetchall(), which has no query result to read after callproc().
    return [params_result]


def db2_connect(host, port, db_name, user, pwd):
    """Open a DB2 connection through the IBM data server driver.

    The low-level ibm_db connection is wrapped in ibm_db_dbi.Connection so
    that execute() can use the DB-API cursor interface on it. Exits with
    UNKNOWN (via bail_out) if the driver modules cannot be imported.
    """
    try:
        import ibm_db  # pylint: disable=import-error
        import ibm_db_dbi  # pylint: disable=import-error
    except ImportError as exc:
        bail_out(3, "%s. Please install it via pip." % exc)

    cstring = ('DRIVER={IBM DB2 ODBC DRIVER};DATABASE=%s;'
               'HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=%s;'
               % (db_name, host, port, user, pwd))
    raw_connection = ibm_db.connect(cstring, '', '')
    dbi_connection = ibm_db_dbi.Connection(raw_connection)
    return dbi_connection


def db2_execute(cursor, cmd, inpt, procedure):
    """No DB2-specific handling; delegate to the generic DB-API path."""
    rows = _default_execute(cursor, cmd, inpt, procedure)
    return rows


#.


def connect(dbms, host, port, db_name, user, pwd):
    """Open a connection to the database management system named by dbms.

    The DBMS-specific *_connect function imports its driver library on
    demand. If port is None, DEFAULT_PORTS[dbms] is used instead.
    Returns the connection object created by that function.
    """
    connectors = {
        'postgres': postgres_connect,
        'mysql': mysql_connect,
        'mssql': mssql_connect,
        'oracle': oracle_connect,
        'db2': db2_connect,
    }
    effective_port = DEFAULT_PORTS[dbms] if port is None else port
    connect_func = connectors[dbms]
    return connect_func(host, effective_port, db_name, user, pwd)


def execute(dbms, connection, cmd, inpt, procedure=False):
    """Execute the sql statement, or call the procedure, then close everything.

    Some corrections are made for libraries that do not adhere to the
    python SQL API: https://www.python.org/dev/peps/pep-0249/
    (see the DBMS specific *_execute functions).

    The cursor and the connection are always closed, even if creating the
    cursor, running the command or closing the cursor raises.
    Returns the rows produced by the DBMS specific execute function.
    """
    executors = {
        'postgres': postgres_execute,
        'mysql': mysql_execute,
        'mssql': mssql_execute,
        'oracle': oracle_execute,
        'db2': db2_execute,
    }

    try:
        # cursor() itself may fail; the connection must be closed then, too.
        cursor = connection.cursor()
        try:
            result = executors[dbms](cursor, cmd, inpt, procedure)
        finally:
            cursor.close()
    finally:
        connection.close()

    LOG.info("SQL Result:\n%r", result)
    return result


def process_result(result, warn, crit, metrics=False, debug=False):
    """Turn the first row (!) of the SQL result into a state and output text.

    Only result[0] is considered; further rows are ignored. It is expected
    to be a sequence [numerical_value, text, performance_data], where text
    and performance_data are optional:
      * with a single column, the number itself is used as text;
      * performance data (third column) is appended only if metrics is
        true; a missing third column is ignored unless debug is true, in
        which case the IndexError propagates.

    If warn or crit differ from MP_INF, the state comes from the levels
    (a number is inside a (lower, upper) range if lower <= number < upper)
    and the number is appended to the text. Otherwise the number itself
    must be a state: 0, 1, 2 or 3.

    Returns (state, text). Exits with UNKNOWN via bail_out() if there is no
    data, or if no levels are given and the number is not a valid state.
    """
    if not result or not result[0]:
        bail_out(3, "SQL statement/procedure returned no data")
    row0 = result[0]

    number = float(row0[0])

    # handle case where sql query only results in one column
    if len(row0) == 1:
        text = "%s" % (row0[0],)
    else:
        text = "%s" % (row0[1],)

    perf = ""
    if metrics:
        try:
            perf = " | performance_data=%s" % str(row0[2])
        except IndexError:
            if debug:
                raise

    state = 0
    if warn != MP_INF or crit != MP_INF:
        if not warn[0] <= number < warn[1]:
            state = 1
        if not crit[0] <= number < crit[1]:
            state = 2
        text += ": %s" % number
    elif number in (0, 1, 2, 3):  # no levels were given
        state = int(number)
    else:
        # "%d" would truncate 2.5 to "2" and raises for nan/inf; show
        # integral values without ".0" and everything else as-is.
        shown = int(number) if number.is_integer() else number
        bail_out(3, "<%s> is not a state, and no levels given" % shown)

    return state, text + perf


def main(argv=None):
    """Run the check: connect, execute, evaluate and print the result.

    argv defaults to sys.argv. The result is reported via bail_out(), which
    exits the process. Unless --debug is given, any exception is reported
    as UNKNOWN naming the step that failed; with --debug it propagates.
    """
    args = parse_args(argv or sys.argv)

    # An empty tuple in an except clause catches nothing.
    catchable = () if args.debug else Exception
    try:
        msg = "connecting to database"
        conn = connect(args.dbms, args.hostname, args.port, args.name, args.user, args.password)

        msg = "executing SQL command"
        result = execute(args.dbms, conn, args.cmd, args.input, procedure=args.procedure)

        msg = "processing result of SQL statement/procedure"
        state, text = process_result(
            result,
            args.warning,
            args.critical,
            metrics=args.metrics,
            debug=args.debug,
        )
    except catchable as exc:
        # Keep the message on a single line: strip surrounding parentheses
        # (presumably from tuple-style driver errors), turn escaped "\n"
        # sequences into spaces, and collapse real line breaks/tabs, which
        # some drivers embed in their messages.
        errmsg = str(exc).strip('()').replace(r'\n', ' ')
        errmsg = ' '.join(errmsg.split())
        bail_out(3, "Error while %s: %s" % (msg, errmsg))

    bail_out(state, args.text + text)


# Script entry point. main() reports the check result through bail_out(),
# which terminates the process via sys.exit().
if __name__ == '__main__':
    main()
