#!/usr/bin/python2
# -*- coding: utf-8 -*-
# Copyright (c) 2009, 2010 Sebastian Wiesner <lunaryorn@googlemail.com>
# Copyright (c) 2010 Jordi Fita <jfita@geishastudios.com>

# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
from lxml import etree
from pygments import lex, highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import get_lexer_by_name
from pygments.token import Token
from lxml.html import fragment_fromstring
import re
from pygments.filter import Filter
import sys

class AtangleFilter(Filter):
    """
    Pygments filter that turns chunk directives such as ``<<name>>`` or
    ``<<name>>=`` into a single ``Token.Name.Label`` token.

    Tokens are buffered until a token whose text starts with ``'\\n'``,
    ``'\\r'`` or ``'='`` arrives.  If the concatenated text of the buffered
    tokens matches ``self.directive`` it is emitted as one label token,
    otherwise the buffered tokens are emitted unchanged.  The terminating
    token is always passed through as-is.
    """

    def __init__(self, **options):
        Filter.__init__(self, **options)
        # A line holding only ``<<*>>`` or ``<<name>>`` (optionally followed
        # by ``=``), where the name may contain word characters, whitespace,
        # ``-`` and ``.``.  Kept as a plain string attribute; ``re`` caches
        # the compiled pattern between calls.
        self.directive = r'''^\s*<<(\*|[-\w\s\.]+)>>=?\s*$'''

    def _flush(self, pending):
        """
        Yield the buffered ``(ttype, value)`` tokens of one line: as a single
        ``Token.Name.Label`` token if their text is a directive, otherwise
        unchanged.
        """
        text = ''.join(value for _, value in pending)
        if re.match(self.directive, text):
            yield Token.Name.Label, text
        else:
            for token in pending:
                yield token

    def filter(self, lexer, stream):
        pending = []
        for ttype, value in stream:
            if value.startswith(('\n', '\r', '=')):
                for token in self._flush(pending):
                    yield token
                # The terminator itself always goes through unchanged.
                yield ttype, value
                # Start buffering the next line.
                pending = []
            else:
                pending.append((ttype, value))
        # A final line without a terminator must be checked as well;
        # previously it was passed through without the directive test.
        for token in self._flush(pending):
            yield token

def html_highlight(context, language, code, config):
    """
    Highlight the given ``code`` in the given ``language``.  ``context`` is
    the XPath context in which this function was applied.  ``config`` is
    ignored.

    ``language`` and ``code`` are indexed with ``[0]``, so they are expected
    to be sequences (e.g. XPath node-sets or lists of strings); only their
    first items are used.  If ``code`` is empty, the concatenated text of
    all text nodes below the context node is highlighted instead.

    Return a list of HTML nodes containing the highlighted code.
    """
    if not code:
        # ``.//text()`` returns one string per text node.  Text interrupted
        # by child elements is split over several nodes, so join them all
        # instead of keeping only the first fragment.  This also avoids an
        # IndexError when the node has no text at all.
        code = [''.join(context.context_node.xpath('.//text()'))]
    lexer = get_lexer_by_name(language[0].lower())
    # Render chunk directives (<<name>>, <<name>>=) as labels.
    lexer.add_filter(AtangleFilter())
    # nowrap=True: emit only the highlighted spans, without pygments' own
    # wrapper; the enclosing element is created below.
    html = highlight(code[0], lexer, HtmlFormatter(nowrap=True))
    # create_parent=True wraps the (possibly multi-node) fragment in a div.
    highlight_div = fragment_fromstring(html, create_parent=True)
    highlight_div.set('class', 'pygments_highlight notranslate')
    return [highlight_div]

def apply_xslt(stylesheet, document):
    """
    Apply the XSLT ``stylesheet`` to ``document`` (both lxml element trees)
    and print the resulting tree to standard output.

    Return the error log collected by the transformation.
    """
    # NOTE: source highlighting through ``html_highlight`` is disabled.
    # Restoring these lines exposes it to stylesheets as ``xhl:highlight``
    # in the xslthl connector namespace:
    #
    #   xhl = etree.FunctionNamespace('http://net.sf.xslthl/ConnectorXalan')
    #   xhl.prefix = 'xhl'
    #   xhl['highlight'] = html_highlight
    xslt = etree.XSLT(stylesheet)
    result_tree = xslt(document)
    print(result_tree)
    return xslt.error_log

def print_errors(errors):
    """
    Write each entry of the lxml error log ``errors`` to standard error,
    one per line.

    Entries of type ``ERR_OK`` (no actual error) are written as their bare
    message; every other entry is prefixed with its level, file name, line
    and column, and suffixed with its type name.
    """
    plain = '{0.message}'
    detailed = ('{0.level_name}:{0.filename}:{0.line},{0.column}: '
                '{0.message} ({0.type_name})')
    for entry in errors:
        template = plain if entry.type == etree.ErrorTypes.ERR_OK else detailed
        sys.stderr.write(template.format(entry) + '\n')

def main():
    """
    Command-line entry point, invoked as ``<script> XSLT_FILE XML_FILE``.

    Parse ``XML_FILE`` and resolve its XIncludes, transform it with the
    stylesheet parsed from ``XSLT_FILE`` via ``apply_xslt`` and report the
    resulting error log via ``print_errors``.

    Return 1 if the number of arguments is wrong, otherwise ``None``.
    """
    arguments = sys.argv[1:]
    if len(arguments) != 2:
        if len(arguments) < 2:
            complaint = 'missing arguments'
        else:
            complaint = 'too many arguments'
        sys.stderr.write(complaint + '\n')
        return 1
    xslt_file, xml_file = arguments

    document = etree.parse(xml_file)
    document.xinclude()
    print_errors(apply_xslt(etree.parse(xslt_file), document))
if __name__ == '__main__':
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        # Interrupted by the user (Ctrl-C): suppress the traceback, but do
        # not report success to the caller.  130 = 128 + SIGINT (2), the
        # status shells use for a process terminated by SIGINT.
        sys.exit(130)
