#!/usr/bin/python
import sys
import logging
logging.basicConfig()

from switzerland.client.PacketDiff import PacketDiffer
from switzerland.common.Dummies    import DummyAlice

import switzerland.lib.shrunk_scapy.utils       as scapy_utils
import switzerland.lib.shrunk_scapy.layers.inet as scapy_inet

# Passed as the third argument to PacketDiffer in handle_file.
# NOTE(review): this binds DummyAlice itself without calling it -- if
# DummyAlice is a class, PacketDiffer receives the class rather than an
# instance; confirm that is intended.
dummy = DummyAlice

# When true, handle_file reports -in files that have no paired -out file.
print_unmatched = True
# When false, compare_pcaps returns -1 for captures whose first packets differ
# in source or destination address, and handle_file then stops without
# printing its report.
print_firewalled = True

def check_args():
    """
    Validate the command line; print usage and exit if it is unacceptable.

    At least one argument must be given, and every argument must contain
    the text "-in.pcap".  Returns None when the arguments are acceptable;
    otherwise prints a usage message and exits with status 1 (SystemExit).
    """
    args = sys.argv[1:]
    # Explicit checks instead of assert statements inside a bare except:
    # asserts are removed under "python -O", and a bare except would also
    # swallow unrelated exceptions such as KeyboardInterrupt.
    if args and all("-in.pcap" in arg for arg in args):
        return
    print("Usage:")
    print("%s %s" % (sys.argv[0], "<-in.pcap file> [-in.pcap file...]"))
    sys.exit(1)

import re
# Matches the literal text "-in"; handle_file substitutes "-out" for it to
# derive the name of the paired -out capture.
in_re = re.compile("-in")
def handle_file(file):
    """
    Look through a -in file (and its paired -out file, if there is one) and
    try to determine which of the -out packets might have been modified to
    produce the forgery.

    file -- path of the -in pcap capture.  The paired capture's path is
            derived by replacing "-in" with "-out" in the file name.

    Prints its findings to stdout and returns None.
    """
    import os.path  # local import: only needed here

    # Rewrite only the final path component, so that a directory whose name
    # happens to contain "-in" is left untouched in the -out path.
    directory, basename = os.path.split(file)
    file2 = os.path.join(directory, in_re.sub("-out", basename))

    packets1 = scapy_utils.rdpcap(file)
    try:
        packets2 = scapy_utils.rdpcap(file2)
    except IOError:
        if print_unmatched:
            print("The -in file is not accompanied by a -out file; "
                  "the packet is probably injected:")
            print(packets1[0].summary())
        return

    results = compare_pcaps(packets1, packets2)
    if results == -1:
        # firewalled
        return

    # packets1 is the -in capture (reported as "Received" below) and packets2
    # the -out capture, so the sent count is len(packets2).
    print("Sent logs: %d packets; Rec'd logs: %d packets" %
          (len(packets2), len(packets1)))

    received = packets1[0]
    if not results:
        print("Probably a spoofed packet or 3rd party retransmission:\n")
        print(received.summary())
        return

    print("------------Modified packet--------------")
    try:
        seq = "seq: " + repr(received.seq)
    except AttributeError:
        # the packet has no seq field; leave it out of the report
        seq = ""
    print("Received: %s id: %s %s" % (received.summary(), received.id, seq))

    # Each result pairs the received packet with one candidate sent packet.
    for n, (recd, sent) in enumerate(results):
        if n > 0:
            print("** Another packet that might have been the one sent:\n")
        print("latency: %s" % (recd.time - sent.time,))
        print(PacketDiffer(str(sent), str(recd), dummy).diff())

def compare_pcaps(packets1, packets2):
    """
    Find the packets in packets2 whose id matches that of packets1[0].

    packets1 -- packets read from the -in capture; only the first is examined
    packets2 -- packets read from the paired -out capture

    Returns a list of (target, candidate) tuples, where target is packets1[0]
    and candidate is each packet of packets2 whose id equals the target's.
    Returns -1 instead when the first packets of the two captures differ in
    source or destination address and print_firewalled is false.

    Raises AssertionError if the payload of either first packet is not a
    scapy IP layer.
    """
    target = packets1[0]
    target_ipid = target.id
    tp = target.payload

    example = packets2[0]
    ep = example.payload
    # Raised explicitly rather than through an assert statement so the check
    # still runs under "python -O"; the exception type is unchanged.
    if not (isinstance(tp, scapy_inet.IP) and isinstance(ep, scapy_inet.IP)):
        raise AssertionError("expected an IP payload in both captures")

    if tp.src != ep.src or tp.dst != ep.dst:
        print("Firewalled %s %s %s %s" % (tp.src, tp.dst, ep.src, ep.dst))
        # NOTE(review): the notice above is printed even when
        # print_firewalled is false; the flag only decides whether the
        # comparison goes on -- confirm that this is intended.
        if not print_firewalled:
            return -1

    return [(target, p) for p in packets2 if p.id == target_ipid]

def main():
    """Check the command line, then analyse each named capture in turn."""
    check_args()
    capture_paths = sys.argv[1:]
    for capture_path in capture_paths:
        handle_file(capture_path)

# Run the analysis only when this file is executed as a script.
if __name__ == "__main__":
    main()
