#!/usr/bin/python
# Copyright (c) 2014 Quanta Research Cambridge, Inc
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#

import fcntl, glob, json, struct, subprocess, sys
from gmpy import mpz
# ioctl request codes, and the layouts/sizes of their argument buffers, that
# the __main__ section passes to fcntl.ioctl() on a /dev/portal_* device.

# Request whose reply is unpacked with struct_traceData: field 0 is used as
# the base address and field 2 as the number of TLP records to fetch.
BNOC_TRACE = 0xc008b508
# Length in bytes of the buffer handed to BNOC_TRACE (presumably
# struct.calcsize(struct_traceData) with an 8-byte native long -- TODO confirm).
Trace_len = 144
struct_traceData = '@ L I I 16I 16I'
# Request that returns one Tlp_len-byte trace record per call.
BNOC_GET_TLP = 0x8008b507
# 24 bytes per record, i.e. the 48 hex digits that print_tlp() expects.
Tlp_len = 24
# Request issued with argument 1 once all trace records have been read.
BNOC_ENABLE_TRACE = 0x8008b508
# Not referenced in the visible part of this file.
Enable_len = 4
# Request that returns one pair of unsigned ints, unpacked with
# struct_changeEntry as (timestamp, value).
PCIE_CHANGE_ENTRY = 0x8008b50f
struct_changeEntry = '@ I I'
ChangeEntry_len = 8

# 64-bit BAR
# Sample trace records, each 48 hex digits long (the record format that
# print_tlp() decodes).
# NOTE(review): not referenced by the visible code; presumably kept as an
# example capture -- confirm before removing.
tlpdatalog = [
    '000000011184ffff4a000001020000040018001068470000',
    '000000020984fff0000000010018000fdf508000e775b7be',
    '000000031184ffff4a0000010200000400180000bad0dada',
    '000000040984fff0000000010018000fdf5040002c7bdba8',
    '000000051184ffff4a0000010200000400180000005a05a0',
    '000000060984fff0000000010018000fdf5000001ce1d27f',
    '000000071184ffff4a0000010200000400180000005b05b0',
    '000000080984fff0000000010018000fdf5040009c773d52',
    '000000091184ffff4a0000010200000400180000005a05a0',
    '0000000a0984fff0000000010018000fdf50c0107d93c3c9',
    '0000000b1184ffff4a000001020000040018001068470000',
    '0000000c0984fff0000000010018000fdf50800099159bf9',
    '0000000d1184ffff4a0000010200000400180000bad0dada',
    '0000000e0984fff0000000010018000fdf5040008fc4124f',
    '0000000f1184ffff4a0000010200000400180000005a05a0',
    '000000100984fff0000000010018000fdf5000000f52fd62',
]

# Display names indexed by the 5-bit packet type that print_tlp() extracts
# from a record (32 entries, one per possible value).
# Two entries are abbreviated: print_tlp() prints the name in a 4-character
# column and compares against the literal 'COMP', so do not rename them
# without updating print_tlp().
TlpPacketType = [
    'MRW', # 'MEMORY_READ_WRITE'
    'MEMORY_READ_LOCKED',
    'IO_REQUEST',
    'UNKNOWN_TYPE_3',
    'CONFIG_0_READ_WRITE',
    'CONFIG_1_READ_WRITE',
    'UNKNOWN_TYPE_6',
    'UNKNOWN_TYPE_7',
    'UNKNOWN_TYPE_8',
    'UNKNOWN_TYPE_9',
    'COMP',
    'COMPLETION_LOCKED',
    'UNKNOWN_TYPE_12',
    'UNKNOWN_TYPE_13',
    'UNKNOWN_TYPE_14',
    'UNKNOWN_TYPE_15',
    'MSG_ROUTED_TO_ROOT',
    'MSG_ROUTED_BY_ADDR',
    'MSG_ROUTED_BY_ID',
    'MSG_ROOT_BROADCAST',
    'MSG_LOCAL',
    'MSG_GATHER',
    'UNKNOWN_TYPE_22',
    'UNKNOWN_TYPE_23',
    'UNKNOWN_TYPE_24',
    'UNKNOWN_TYPE_25',
    'UNKNOWN_TYPE_26',
    'UNKNOWN_TYPE_27',
    'UNKNOWN_TYPE_28',
    'UNKNOWN_TYPE_29',
    'UNKNOWN_TYPE_30',
    'UNKNOWN_TYPE_31'
]

# Display names indexed by the 3-bit packet format that print_tlp() extracts
# from a record (8 entries).
# print_tlp() selects its decoding branch by comparing against these strings
# exactly, so the trailing spaces in the first two entries are significant.
TlpPacketFormat = [
    'MEM_READ__3DW     ',
    'MEM_READ__4DW     ',
    'MEM_WRITE_3DW_DATA',
    'MEM_WRITE_4DW_DATA',
    'TLP Prefix',
    'UNKNOWN_FMT_5',
    'UNKNOWN_FMT_6',
    'UNKNOWN_FMT_7',
]

# Names for the source of a PCIe status change. The __main__ section indexes
# this list with the low byte of the value returned by the PCIE_CHANGE_ENTRY
# ioctl when printing 'pcie status change' lines.
Pcie3ChangeSrc = [
    'none',
    'current_speed',
    'dpa_substate_change',
    'err_cor_out',
    'err_fatal_out',
    'err_nonfatal_out',
    'flr_in_process',
    'function_power_state',
    'function_status',
    'hot_reset_out',
    'link_power_state',
    'ltr_enable',
    'ltssm_state',
    'max_payload',
    'max_read_req',
    'negotiated_width',
    'obff_enable',
    'phy_link_down',
    'phy_link_status',
    'pl_status_change',
    'power_state_change_interrupt',
    'rcb_status',
    'backpressure_count'
]

# Running state for emit_vcd_entry(), which rebinds all three via `global`.
# First non-zero timestamp seen (0 until then).
first_vcd_timestamp = mpz(0)
# Timestamp of the most recent entry that was written as a value change.
last_vcd_timestamp = mpz(0)
# VCD identifier code of the wire raised by that entry (None until one is written).
last_vcd_pktclass_code = None

def byteswap(be):
    """Return the 8-hex-digit string *be* with its four bytes in reverse order.

    The two-character pairs are emitted last-to-first, so '12345678'
    becomes '78563412'. Pairs that fall beyond the end of a shorter
    string contribute nothing.
    """
    pairs = [be[start:start + 2] for start in (6, 4, 2, 0)]
    return ''.join(pairs)

# Maps a packet-class label to the single-character VCD identifier code of
# the wire that represents it; emit_vcd_header() declares one wire per entry
# and emit_vcd_entry() toggles them.
# NOTE(review): pktClassification() in this file returns 'CpuRRsp', 'DmaRRsp',
# 'CpuRCon', 'DmaRCon', 'Interru' and 'Misc', none of which are keys here, so
# emit_vcd_entry() writes those classes as $comment lines rather than value
# changes -- confirm whether that is intended.
pktclassCodes = {
    'CpuRReq': 'S',
    'CpuWReq': 'T',
    'CpuRResp': 's',
    '(to) slave continuation': 'c',
    'DmaWReq': 'W',
    'DmaRReq': 'M',
    'DmaRResp': 'm',
    '(to) master continuation': 'C',
    'trace': 't',
}

# Preamble written by emit_vcd_header(); %(vars)s is replaced with the
# newline-joined $var declarations.
vcd_header_template='''
$version
   tlp.py
$end
$comment
$end
$timescale 8ns $end
$scope module logic $end
%(vars)s
$upscope $end
$enddefinitions $end
'''

# $dumpvars section template. Not used by the visible code: emit_vcd_header()
# supplies a 'dumpvars' value but only formats vcd_header_template, which has
# no %(dumpvars)s placeholder.
unused='''
$dumpvars
%(dumpvars)s
$end
'''

def emit_vcd_header(f):
    """Write the VCD preamble to *f*, declaring a 1-bit wire per packet class.

    Each key of pktclassCodes becomes a wire whose name is the key in lower
    case with spaces replaced by underscores.
    """
    var_decls = []
    initial_values = []
    for label in pktclassCodes:
        code = pktclassCodes[label]
        wire_name = label.lower().replace(' ', '_')
        var_decls.append('$var wire 1 %s %s $end' % (code, wire_name))
        initial_values.append('0%s' % code)
    substitutions = {
        'vars': '\n'.join(var_decls),
        # vcd_header_template has no placeholder for this key.
        'dumpvars': '\n'.join(initial_values),
    }
    f.write(vcd_header_template % substitutions)

def emit_vcd_entry(f, timestamp, pktclass):
    """Append the VCD value change for one packet to the open file *f*.

    timestamp -- the packet's sequence number (an mpz parsed from 8 hex
        digits by print_tlp); a zero timestamp is ignored.
    pktclass -- packet-class label. Labels that are keys of pktclassCodes
        raise the matching wire; any other label is written as a
        '$comment <label> $end' line and leaves the state untouched.

    Updates the module globals first_vcd_timestamp, last_vcd_timestamp and
    last_vcd_pktclass_code.
    """
    global first_vcd_timestamp, last_vcd_timestamp, last_vcd_pktclass_code
    if not timestamp:
        return
    if not first_vcd_timestamp:
        first_vcd_timestamp = timestamp

    # A timestamp lower than the previous one is treated as a wrap of the
    # 32-bit counter: shift it up by 2**32 and record the adjustment.
    # NOTE(review): only one 2**32 offset is ever added, so a second wrap
    # within the same log would not be corrected.
    if last_vcd_timestamp and (timestamp < last_vcd_timestamp):
        wrap = mpz('100000000', 16)
        f.write('$comment %s %s %s $end\n' % (hex(last_vcd_timestamp), hex(timestamp), hex(timestamp + wrap)))
        timestamp = timestamp + wrap
        f.write('$comment %s %s $end\n' % (hex(timestamp), hex(timestamp - first_vcd_timestamp)))

    # If there is a gap since the previous change, drop the previous wire one
    # tick after it was raised so that it shows up as a one-tick pulse.
    if last_vcd_timestamp and timestamp > (last_vcd_timestamp+1):
        f.write('#%s\n0%s\n' % ((last_vcd_timestamp+mpz(1)), last_vcd_pktclass_code))

    # `in` replaces the deprecated dict.has_key(), which Python 3 removed.
    if pktclass in pktclassCodes:
        pktclass_code = pktclassCodes[pktclass]
        f.write('#%s\n' % timestamp)
        f.write('1%s\n' % pktclass_code)
        # Switching to a different wire: lower the one that was up.
        if last_vcd_pktclass_code and last_vcd_pktclass_code != pktclass_code:
            f.write('0%s\n' % last_vcd_pktclass_code)
        last_vcd_pktclass_code = pktclass_code
        last_vcd_timestamp = timestamp
    else:
        f.write('$comment %s $end\n' % pktclass)

def pktClassification(tlpsof, tlpeof, tlpbe, pktformat, pkttype, portnum, address):
    """Return a 7-character (or 'Misc') label classifying one trace record.

    tlpsof -- start-of-frame flag; 0 marks a continuation record.
    tlpeof, tlpbe -- accepted but not consulted.
    pktformat -- packet format code; 2 and 3 are the write-with-data formats.
    pkttype -- packet type code; 10 is a completion.
    portnum -- port the record was captured on; only 4 and 8 are classified.
    address -- address as a hex string; inspected only for a write on
        port 8, where a leading 'fee' marks an interrupt.
    """
    on_port4 = portnum == 4

    # Continuation records are labelled purely by port.
    if tlpsof == 0:
        return 'DmaRCon' if on_port4 else 'CpuRCon'

    if not on_port4 and portnum != 8:
        return 'Misc'

    is_completion = pkttype == 10
    is_write = pktformat in (2, 3)

    if on_port4:
        if is_completion:
            return 'DmaRRsp'
        return 'CpuWReq' if is_write else 'CpuRReq'

    # Port 8 from here on.
    if is_completion:
        return 'CpuRRsp'
    if not is_write:
        return 'DmaRReq'
    if address[0:3] == 'fee':
        return 'Interru'
    return 'DmaWReq'

# Number of records seen per packet-class label; filled in by print_tlp()
# and printed by the __main__ section.
classCounts = {}
# Sequence number of the previously printed record, used by print_tlp() to
# compute the delta column; negative until the first record has been printed.
last_seqno = mpz(-1)

def interfaceId(interfaceNumber):
    """Return the fooMap entry for *interfaceNumber*.

    Numbers that are not a valid index into fooMap are returned formatted
    as lower-case hex instead.
    """
    if not (0 <= interfaceNumber < len(fooMap)):
        return '%x' % interfaceNumber
    return fooMap[interfaceNumber]

def interfaceName(interfaceNumber):
    """Return the interfaceMap value for the interface at *interfaceNumber*.

    The number is first translated to a name through fooMap and that name
    is looked up in interfaceMap. Numbers that are not a valid index into
    fooMap are returned formatted as lower-case hex instead.
    """
    if not (0 <= interfaceNumber < len(fooMap)):
        return '%x' % interfaceNumber
    name = fooMap[interfaceNumber]
    return interfaceMap[name]

def print_tlp(tlpdata, f=None):
    """Decode one trace record and print it as a single formatted line.

    tlpdata -- 48 hex digits (one 24-byte record). As decoded below:
        [0:8]    sequence number (also used as the VCD timestamp)
        [8:10]   port number in the upper 7 bits, start-of-frame in bit 0
        [10:12]  end-of-frame in bit 7, "hit" value in the low 7 bits
        [12:16]  byte enables
        [16:48]  four 32-bit words: the TLP header for a start-of-frame
                 record, raw data for a continuation record
    f -- optional open VCD file; when given, the record is also passed to
        emit_vcd_entry().

    Side effects: increments classCounts for the record's class and updates
    last_seqno. A record that is not exactly 48 characters long prints
    'bogus len <n>' and terminates the program with sys.exit(1).

    Reads the module globals base, baseLimit and interfaceMap, which are
    assigned by the __main__ section.
    """
    global last_seqno

    # Validate before parsing: on a short record the int() calls below would
    # otherwise fail with a ValueError before this diagnostic was reached.
    if len(tlpdata) != 48:
        print('bogus len %d' % len(tlpdata))
        sys.exit(1)

    seqno = mpz(tlpdata[0:8],16)
    if last_seqno >= 0:
        delta = seqno - last_seqno
    else:
        delta = 0
    tlpsof = int(tlpdata[9:10],16) & 1
    tlpeof = int(tlpdata[10:12],16) >> 7
    tlpbe  = tlpdata[12:16]
    tlphit = int(tlpdata[10:12],16) & 0x7f
    pktformat = (int(tlpdata[16:17],16) >> 1) & 7
    pkttype = (int(tlpdata[16:18],16) & 0x1f)
    portnum = int(tlpdata[8:10],16) >> 1

    formatname = TlpPacketFormat[pktformat]
    is_3dw = formatname == 'MEM_READ__3DW     ' or formatname == 'MEM_WRITE_3DW_DATA'
    is_4dw = formatname == 'MEM_READ__4DW     ' or formatname == 'MEM_WRITE_4DW_DATA'

    # pktClassification() inspects the address as a hex string; 0 stands for
    # "no address" on formats that do not carry one.
    if is_4dw:
        class_address = tlpdata[32:]
    elif is_3dw:
        class_address = tlpdata[32:40]
    else:
        class_address = 0
    pktclass = pktClassification(tlpsof, tlpeof, tlpbe, pktformat, pkttype, portnum, class_address)
    classCounts[pktclass] = classCounts.get(pktclass, 0) + 1

    if f:
        emit_vcd_entry(f, seqno, pktclass)

    headerstr = '%6s' % (pktclass)
    if tlpsof:
        headerstr += ':%4s:%18s' % (TlpPacketType[pkttype], formatname)
    else:
        # Blank padding as wide as the type/format columns above.
        headerstr += '                        '
    headerstr += ' ' + tlpdata[8:10] + ' ' + hex(portnum)
    headerstr += ' tlp(%s %d %d %d)' % (tlpbe, tlphit, tlpeof, tlpsof)

    # Numeric address of a memory request; -1 when the record has none.
    address = -1
    if tlpsof == 0:
        # Continuation record: everything after the prefix is payload.
        headerstr += '                            data:' + tlpdata[16:]
    elif formatname == 'MEM_WRITE_3DW_DATA' and TlpPacketType[pkttype] == 'COMP':
        # Completion with data.
        headerstr += '                        tag:' + tlpdata[36:38]
        headerstr += ' ' + tlpdata[32:36]
        headerstr += ' ' + tlpdata[24:28]
        headerstr += ' ' + tlpdata[28:29]
        headerstr += ' ' + tlpdata[20:21] + str(int(tlpdata[20:21],16) >> 3)
        headerstr += ' ' + tlpdata[29:32]
        headerstr += ' ' + tlpdata[38:40]
        headerstr += ' %3d' % (int(tlpdata[21:24],16) & 0x3ff)
        headerstr += ' ' + byteswap(tlpdata[40:])
    elif is_3dw:
        address = tlpdata[32:40]
        headerstr += '  ' + address
        address = int(address,16)
        headerstr += ' %4x'% ((address >> 2) % 8192)
        headerstr += ' be(' + tlpdata[31:32] + ' ' + tlpdata[30:31] + ')'
        headerstr += ' tag:' + tlpdata[28:30]
        headerstr += ' ' + tlpdata[24:28]
        headerstr += '                  %3d' % (int(tlpdata[21:24],16) & 0x3ff)
        if formatname == 'MEM_WRITE_3DW_DATA':
            headerstr += ' ' + byteswap(tlpdata[40:])
    elif is_4dw:
        address = tlpdata[32:]
        headerstr += ' address: ' + address
        address = int(address,16)
        headerstr += ' be(1st: ' + tlpdata[31:32] + ' last:' + tlpdata[30:31] + ')'
        headerstr += ' tag:' + tlpdata[28:30]
        headerstr += ' reqid:' + tlpdata[24:28]
        headerstr += ' length:' + str(int(tlpdata[21:24],16) & 0x3ff)
    elif formatname == 'TLP Prefix':
        headerstr += tlpdata
    else:
        # Unrecognised format: dump every field with a label.
        headerstr += '  tlp data:' + tlpdata[40:]
        headerstr += 'lower addr:' + tlpdata[38:40]
        headerstr += '       tag:' + tlpdata[36:38]
        headerstr += '     reqid:' + tlpdata[34:36]
        headerstr += ' bytecount:' + '0x' + tlpdata[33:34]
        headerstr += '       bcm:' + str(int(tlpdata[32:33], 16) & 1)
        headerstr += '   cstatus:' + str((int(tlpdata[32:33], 16) >> 1) & 7)
        headerstr += '    cmplid:' + tlpdata[30:32]
        headerstr += '    cmplen:' + tlpdata[27:30]
        headerstr += '   nosnoop:' + str(int(tlpdata[26:27],16) & 1)
        headerstr += '   relaxed:' + str(int(tlpdata[26:27],16) & 2)
        headerstr += '   poison:' + str(int(tlpdata[26:27],16) & 4)
        headerstr += '   digest:' + str(int(tlpdata[26:27],16) & 8)
        headerstr += '     zero:' + tlpdata[25:26]
        headerstr += '   tclass:' + tlpdata[24:25]
        headerstr += '  pkttype:' + str(int(tlpdata[22:24],16) & 0x1f) + ' ' + TlpPacketType[int(tlpdata[22:24],16) & 0x1f]
        headerstr += '  format:' + str((int(tlpdata[22:24],16) >> 1) & 3) + ' ' + TlpPacketFormat[(int(tlpdata[22:24],16) >> 1) & 3]

    # Leading tag: port direction followed by the record kind.
    if portnum == 4:
        direction = 'RX'
    elif portnum == 8:
        direction = 'TX'
    else:
        direction = '__'
    if tlpsof == 0:
        direction = direction + 'cc'    # continuation
    elif pkttype == 10:
        direction = direction + 'pp'    # response
    else:
        direction = direction + 'qq'    # request

    # Translate an address inside the window [base, baseLimit) into
    # interface/method/word-offset form. Each interface spans 0x10000 bytes
    # and each method 0x100 bytes. Addresses below base are skipped: they
    # would yield a negative interface number and a meaningless label.
    smeth = ''
    if address != -1 and base <= address < baseLimit and interfaceMap != {}:
        address = address - base
        interfaceNumber = address // 0x10000
        offset = address - interfaceNumber * 0x10000
        methodNumber = offset // 0x100
        t = interfaceName(interfaceNumber)
        if methodNumber == 0:
            smeth = '/dir'
        elif methodNumber <= len(t):
            offset = offset - methodNumber * 0x100
            smeth = '/'+t[methodNumber-1]
        smeth = ('%02d' % interfaceNumber) + '/' + interfaceId(interfaceNumber) + smeth + '/' + "%x" % (offset//4)

    # A single pre-formatted argument keeps the output identical whether
    # `print` is the Python 2 statement or the Python 3 function.
    print('%s %s %s' % (direction, '%10d %10d %s' % (seqno, delta, headerstr), smeth))
    last_seqno = seqno

def print_tlp_log(tlplog, f=None, lf=None):
    """Print a column legend, then every usable record in *tlplog*.

    tlplog -- iterable of record strings.
    f -- optional VCD file; receives the VCD header here and is handed on
        to print_tlp() for each record.
    lf -- optional log file; each record that is printed is first written
        to it, one per line.

    Empty strings and records beginning with '00000000' are skipped.
    """
    if f:
        emit_vcd_header(f)

    # Two-line legend for the columns produced by print_tlp().
    legend = (
        '             ts     delta   response                     XXX          tlp          address  off   be       tag     clid  nosnp  laddr        data',
        '                                pkttype format               foo (be hit eof sof)            (1st last)        req     stat  bcnt    length',
    )
    for legend_line in legend:
        print(legend_line)

    for record in tlplog:
        if record == '' or record.startswith('00000000'):
            continue
        if lf:
            lf.write(record+'\n')
        print_tlp(record, f)

if __name__ == '__main__':
    # Usage: tlp.py [-j name0:name1:... interfaces.json] [tlplogfile]
    #
    # With a log file argument the records are read from that file; otherwise
    # they are fetched from the first /dev/portal_* device. The decoded
    # listing goes to stdout, the raw records to tlp.log and a waveform to
    # tlp.vcd.
    argindex = 1
    argcount = len(sys.argv)
    jsondata = {}
    interfaceMap = {}
    # Interface names by interface number; only filled in by -j.
    fooMap = []
    # Base address of the portal window. Only the device path learns the real
    # value; defaulting to 0 keeps baseLimit computable when replaying a saved
    # log, which previously failed with a NameError.
    base = 0
    if argcount >= 2 and sys.argv[argindex] == '-j':
        if argcount < argindex + 3:
            sys.stderr.write('usage: %s [-j name0:name1:... interfaces.json] [tlplogfile]\n' % sys.argv[0])
            sys.exit(1)
        fooMap = sys.argv[argindex+1].split(':')
        print('FFF %s' % (fooMap,))
        with open(sys.argv[argindex + 2]) as jsonfile:
            jsondata = json.loads(jsonfile.read())
        argindex = argindex + 3
        argcount = argcount - 3
        # interface name -> list of its method names, in declaration order
        for item in jsondata['interfaces']:
            interfaceMap[item['name']] = [mitem['name'] for mitem in item['decls']]
    devfilenames = glob.glob('/dev/portal_*')
    print('pcieflat: devices are %s' % (devfilenames,))
    if argcount == 2:
        with open(sys.argv[argindex]) as logfile:
            tlplog = logfile.read().split('\n')
    else:
        if not devfilenames:
            sys.stderr.write('pcieflat: no /dev/portal_* device found and no log file given\n')
            sys.exit(1)
        # 'rw' is not a valid mode string; the C library ignored the 'w' and
        # opened the device read-only, which is what 'r' states explicitly.
        fd = open(devfilenames[0], 'r')
        try:
            # Best effort: drain and print the PCIe status change log. Any
            # failure here (for example an unsupported ioctl) is ignored, but
            # KeyboardInterrupt and SystemExit are no longer swallowed.
            try:
                while 1:
                    buf = fcntl.ioctl(fd, PCIE_CHANGE_ENTRY, ' ' * ChangeEntry_len)
                    (ts,v) = struct.unpack_from(struct_changeEntry, buf)
                    if (v == 0 or v == 0xbad0add0 or ts == 0xbad0add0):
                        break
                    src = v & 0xff
                    value = v >> 8
                    print('pcie status change: %12d src=%20s:%02d value=%#06x' % (ts, Pcie3ChangeSrc[src], src, value))
            except Exception:
                pass
            buf = fcntl.ioctl(fd, BNOC_TRACE, ' ' * Trace_len)
            traceData = struct.unpack_from(struct_traceData, buf)
            base = traceData[0]
            # The first entry is the base address as 16 hex digits.
            tlplog = ['%016x' % traceData[0]]
            counter = traceData[2]
            while counter > 0:
                counter = counter - 1
                buf = fcntl.ioctl(fd, BNOC_GET_TLP, ' ' * Tlp_len)
                # Hex-encode the record with its bytes in reverse order.
                tlplog.append(''.join(["%02x" % ord(x) for x in reversed(buf)]))
            fcntl.ioctl(fd, BNOC_ENABLE_TRACE, 1)
        finally:
            fd.close()
    # Records start with their sequence number in fixed-width hex, so a plain
    # string sort orders them by sequence number.
    tlplog.sort()
    # The window covers 16 interfaces of 0x10000 bytes each.
    baseLimit = base + 16 * 0x10000
    lf = open('tlp.log', 'w')
    f = open('tlp.vcd', 'w')
    try:
        print_tlp_log(tlplog, f, lf)
        print(classCounts)
        print(sum(classCounts.values()))
    finally:
        # print_tlp() may call sys.exit(); make sure both files are closed.
        f.close()
        lf.close()
