#!/usr/bin/python

## bat-romfsck
## Copyright 2013 Armijn Hemel for Tjaldur Software Governance Solutions
## Licensed under GNU GPLv2

'''
Reimplementation of Harald Welte's romfsck.c which segfaults when encountering
a file that is not romfs. This implementation is hopefully a bit more robust.

See Documentation/filesystems/romfs.txt in the Linux kernel source tree for
the exact format.
'''

import os, sys, os.path, struct, shutil
from optparse import OptionParser

## The checksum in romfs is basically:
## * chop the data up in int32 (big endian)
## * add them
## * squeeze it in an int32 and just use the remainder
def computechecksum(romfsbytes):
	'''
	Return the romfs style checksum of romfsbytes.

	The data is interpreted as a sequence of big endian unsigned 32 bit
	integers. These are added up and the sum is reduced modulo 2^32.

	romfsbytes is a byte string. If its length is not a multiple of 4 the
	last, incomplete, word is padded with 0x00 bytes on the right instead
	of raising struct.error. An empty byte string has checksum 0.
	'''
	## pad a trailing partial word so that unpacking can never fail on
	## truncated input
	leftover = len(romfsbytes) % 4
	if leftover != 0:
		romfsbytes = romfsbytes + b'\x00' * (4 - leftover)

	## unpack all words in one go instead of one struct call per word
	words = struct.unpack('>%dL' % (len(romfsbytes) // 4), romfsbytes)

	## squeeze the sum into 32 bits
	return sum(words) & 0xffffffff

## Unpack the file system, do basic sanity checks first.
def unpackromfs(romfsfile, unpackdir):
	'''
	Unpack the romfs image stored in the file romfsfile into the existing
	directory unpackdir.

	Regular files, directories, symbolic links and hard links to regular
	files that were unpacked earlier are recreated. Block devices, character
	devices, sockets and fifos are only sanity checked, not recreated. The
	'.' and '..' entries are skipped.

	Returns True if the whole image could be walked and None as soon as a
	sanity check fails. In that case everything unpacked so far is left in
	unpackdir and the caller is responsible for cleaning up. Errors from the
	operating system while writing results (for example a full disk or a
	name that already exists) are not caught and propagate to the caller.
	'''
	## File names inside the image are raw bytes. Do all path handling in
	## bytes so text and bytes are never mixed.
	if not isinstance(unpackdir, bytes):
		if hasattr(os, 'fsencode'):
			unpackdir = os.fsencode(unpackdir)
		else:
			unpackdir = unpackdir.encode(sys.getfilesystemencoding() or 'utf-8')

	## 1. check the size of the file. The minimum size of romfs
	## is 32 bytes
	statsize = os.stat(romfsfile).st_size
	if statsize < 32:
		return None

	## read parts of the file system and do some sanity checks. The image
	## is binary data, so it has to be opened in binary mode. Only the
	## first 16 bytes are read until the file looks like romfs.
	romfs = open(romfsfile, 'rb')
	try:
		romfsdata = romfs.read(16)
		if len(romfsdata) < 16:
			return None

		## 2. first verify the first 8 bytes are -rom1fs-
		if romfsdata[:8] != b'-rom1fs-':
			return None

		## 3. verify the size of the header. If it has some
		## bizarre value (like bigger than the file it can unpack)
		## it is not a valid romfs file system
		romfssize = struct.unpack('>L', romfsdata[8:12])[0]
		if romfssize > statsize:
			return None

		## 4. bytes 12 - 16 contain the checksum. It is currently
		## not verified, see the TODO below.

		romfs.seek(0)
		romfsdata = romfs.read()
	finally:
		romfs.close()

	## use the amount of data that was actually read for all bounds checks
	romfsfilesize = len(romfsdata)

	## 5. the volume name is zero terminated and padded with 0x00 to a
	## 16 byte boundary. The first file header starts at the first 16 byte
	## boundary after the terminating 0x00, also when that 0x00 itself is
	## on a boundary (empty name, or a name of 16, 32, ... characters).
	volumenameend = romfsdata.find(b'\x00', 16)
	if volumenameend == -1:
		return None
	fileheaderstart = ((volumenameend // 16) + 1) * 16
	if romfsdata[volumenameend:fileheaderstart].strip(b'\x00') != b'':
		return None

	## romfs documentation talks about "first 512 bytes"
	## but it is unclear what that means to me. Does this include
	## the header? The file name? The raw data of the files without
	## the headers? Questions, questions...
	## TODO: verify the checksum

	## keep a stack of (offset, directory) pairs to return to when the
	## end of a subdirectory is reached
	returnstack = []

	## offsets plus full paths of unpacked regular files, for dealing
	## properly with hardlinks
	offsetnames = {}

	## offsets of the headers that were already processed. Every header
	## should be visited exactly once, so seeing one twice means the next
	## pointers form a loop and the image is bogus.
	seenoffsets = set()

	while fileheaderstart != 0:
		## check for bogus values: loops and headers that do not fit
		## in the data
		if fileheaderstart in seenoffsets:
			return None
		if fileheaderstart + 16 > romfsfilesize:
			return None
		seenoffsets.add(fileheaderstart)

		## Get the pointer to the next entry in the file system. Since the four
		## least significant bits are reused for the executable bit and the file
		## type some fumbling is needed. The second word is the type specific
		## info and the third word the size. The fourth word is a checksum,
		## which is not used here.
		(rwxplusheader, specialinfo, filesize) = struct.unpack('>LLL', romfsdata[fileheaderstart:fileheaderstart+12])
		nextheader = rwxplusheader & 0xfffffff0

		## the type of the file is in the lowest three bits. The fourth
		## bit is the executable bit, which is not needed for unpacking.
		types = rwxplusheader & 0x7

		## check if there are no bogus values
		if filesize > romfsfilesize:
			return None

		## the name is zero terminated and padded to a 16 byte boundary,
		## the data starts at the first boundary after the terminator.
		filenameend = romfsdata.find(b'\x00', fileheaderstart+16)
		if filenameend == -1:
			return None
		filename = romfsdata[fileheaderstart+16:filenameend]
		datastart = ((filenameend // 16) + 1) * 16

		## . and .. are a bit special: they are never unpacked
		if filename != b'.' and filename != b'..':
			## A name inside a directory can never be empty or contain a
			## slash. Refuse these so a crafted image cannot write outside
			## of the unpack directory.
			if filename == b'' or b'/' in filename:
				return None
			destination = os.path.join(unpackdir, filename)

			## hardlink
			if types == 0:
				if filesize != 0:
					return None
				if specialinfo in offsetnames:
					os.link(offsetnames[specialinfo], destination)
				else:
					## Does this really ever happen?
					## Testing with various file systems suggests
					## this is not the case.
					pass

			## entry is a directory. Create it in the output directory
			## and record the old directory and offset, so it is possible
			## to return to the parent directory later on.
			elif types == 1:
				if filesize != 0:
					return None
				if nextheader != 0:
					returnstack.append((nextheader, unpackdir))
				unpackdir = destination
				try:
					os.makedirs(unpackdir)
				except OSError:
					## only acceptable if the directory is already there
					if not os.path.isdir(unpackdir):
						return None
				## continue with the first entry inside the directory
				nextheader = specialinfo

			## the file is a normal file so it can be written out to
			## unpackdir
			elif types == 2:
				## sanity check, specialinfo has to be 0
				if specialinfo != 0:
					return None
				## the data has to be completely inside the image
				if datastart + filesize > romfsfilesize:
					return None
				outfile = open(destination, 'wb')
				try:
					outfile.write(romfsdata[datastart:datastart+filesize])
				finally:
					outfile.close()
				offsetnames[fileheaderstart] = destination

			## symbolic link. This link should be recreated
			elif types == 3:
				## sanity check, specialinfo has to be 0
				if specialinfo != 0:
					return None
				## the target has to be completely inside the image
				if datastart + filesize > romfsfilesize:
					return None
				os.symlink(romfsdata[datastart:datastart+filesize], destination)

			## block device or char device, skip
			elif types == 4 or types == 5:
				## sanity check, size has to be 0
				if filesize != 0:
					return None

			## socket or fifo, skip
			else:
				## sanity checks, size and specialinfo have to be 0
				if filesize != 0:
					return None
				if specialinfo != 0:
					return None

		## If this was the last entry in a subdirectory the previous
		## values for the directory and offset should be restored. This
		## has to happen for every type of entry, otherwise the rest of
		## the file system is silently skipped.
		if nextheader == 0 and returnstack != []:
			(nextheader, unpackdir) = returnstack.pop()
		fileheaderstart = nextheader
	return True

def main(argv, oldskool=False):
	'''
	Command line entry point: unpack a romfs file into an empty directory.

	argv is the full argument list, including the program name as the
	first element.

	oldskool selects how argv is interpreted. It is False by default. Set
	it to True to emulate the old behaviour of romfsck, in which case the
	invocation has to be: programname -x unpackdir romfsfile
	Otherwise the options -b/--binary (path to the romfs file) and
	-d/--directory (path to the unpacking directory) are used.

	The unpacking directory has to exist and has to be empty. If unpacking
	fails everything that was unpacked is removed from the unpacking
	directory again and the program exits with exit code 1. Invalid
	parameters also make the program exit.
	'''
	if not oldskool:
		parser = OptionParser()
		parser.add_option("-b", "--binary", action="store", dest="romfsfile", help="path to binary file", metavar="FILE")
		parser.add_option("-d", "--directory", action="store", dest="unpackdir", help="path to unpacking directory", metavar="DIR")
		## parse the arguments that were passed in, not the global sys.argv
		(options, args) = parser.parse_args(argv[1:])

		## parser.error() prints the message and exits, so each check
		## below is only reached if the previous ones passed.
		if options.romfsfile is None:
			parser.error("Path to romfs file needed")
		if not os.path.exists(options.romfsfile):
			parser.error("romfs file does not exist")
		if options.unpackdir is None:
			parser.error("Path to unpack directory needed")
		if not os.path.exists(options.unpackdir):
			parser.error("unpack directory does not exist")
		if os.listdir(options.unpackdir) != []:
			parser.error("unpack directory %s not empty" % options.unpackdir)
		romfsfile = options.romfsfile
		unpackdir = options.unpackdir
	else:
		## This is so ugly
		## argv[0] == programname
		## argv[1] == '-x'
		## argv[2] == unpackdir
		## argv[3] == romfsfile
		if len(argv) != 4:
			sys.stderr.write("Error: wrong number of parameters\n")
			sys.exit(1)
		if argv[1] != '-x':
			sys.stderr.write("Error: wrong invocation of command\n")
			sys.exit(1)

		unpackdir = argv[2]
		if not os.path.exists(unpackdir):
			sys.stderr.write("Error: directory %s does not exist\n" % unpackdir)
			sys.exit(1)
		if os.listdir(unpackdir) != []:
			sys.stderr.write("Error: unpack directory %s not empty\n" % unpackdir)
			sys.exit(1)

		romfsfile = argv[3]
		if not os.path.exists(romfsfile):
			sys.stderr.write("Error: file %s does not exist\n" % romfsfile)
			sys.exit(1)

	res = unpackromfs(romfsfile, unpackdir)
	if res is None:
		## Unpacking failed, so remove whatever was unpacked. os.listdir()
		## returns bare names, so they have to be joined with unpackdir
		## first, or files in the current working directory would be
		## removed instead.
		for u in os.listdir(unpackdir):
			leftover = os.path.join(unpackdir, u)
			try:
				## a symbolic link to a directory is a link, not a tree
				if os.path.isdir(leftover) and not os.path.islink(leftover):
					shutil.rmtree(leftover)
				else:
					os.unlink(leftover)
			except (OSError, IOError):
				sys.stderr.write("can't remove %s\n" % u)
		sys.exit(1)

## Only unpack when this file is run as a script, not when it is imported.
## The complete argument list, including the program name, is handed to main().
if __name__ == "__main__":
	main(sys.argv)
