#!/usr/bin/python3 -u

import argparse
import hashlib
import os
import re
import shutil
import subprocess
import sys
import time
import urllib.parse
import urllib.request
import zipfile
import zlib
from pathlib import Path
from tqdm import tqdm

# Scratch directory path, unique per process: /tmp/<script name>-<pid>.
# Only the path is computed here; nothing is created on disk at import time.
TMP_DIR = os.path.join('/tmp', os.path.basename(sys.argv[0]) + '-' + str(os.getpid()))
# Per-user download cache. expanduser() honours $HOME when it is set and
# otherwise falls back to the platform's own lookup of the home directory,
# instead of raising KeyError when HOME is missing from the environment.
CACHE_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'dosemu')

# One to eight base-name characters, optionally followed by a dot and one to
# three extension characters. Dots and whitespace are excluded from both parts
# so that names with more than one dot (e.g. "a.b.c") are not taken for 8.3.
eight_dot_three_matcher = re.compile(r'^[^\s.]{1,8}(\.[^\s.]{1,3})?$')

def is_eight_dot_three_filename(filename):
    """Return True when ``filename`` has the shape of a DOS 8.3 file name.

    Args:
        filename: Bare file name (no directory part) to classify.

    Returns:
        bool: True for names such as ``COMMAND.COM`` or ``readme``; False for
        names that are too long, contain whitespace or more than one dot.
    """
    # fullmatch() rejects a trailing newline, which a plain '$' would accept.
    return eight_dot_three_matcher.fullmatch(filename) is not None

def download_file(source, destination_directory):
    """Download ``source`` into ``destination_directory`` unless it is already there.

    The local file name is the percent-decoded last path segment of the URL;
    it is lower-cased when it looks like a DOS 8.3 name. If a file with that
    name already exists, nothing is downloaded.

    The data is first written to ``<name>.part`` and only renamed to the final
    name once the transfer has completed, so an interrupted download is never
    mistaken for a finished one on the next run.

    Args:
        source: URL to fetch.
        destination_directory: Existing directory that receives the file.

    Returns:
        str: Path of the (possibly pre-existing) local file.

    Raises:
        urllib.error.URLError: If the URL cannot be opened.
        OSError: If the file cannot be written.
    """
    destination_file = urllib.parse.unquote(source.split("/")[-1])
    if is_eight_dot_three_filename(destination_file):
        destination_file = destination_file.lower()
    destination_file = os.path.join(destination_directory, destination_file)
    if Path(destination_file).is_file():
        return destination_file

    print('Downloading ' + source + '...')
    partial_file = destination_file + '.part'
    try:
        with urllib.request.urlopen(source) as response, open(partial_file, 'wb') as f:
            headers = getattr(response, 'headers', None)
            length = headers.get('Content-Length') if headers is not None else None
            try:
                total = int(length) if length is not None else None
            except ValueError:
                # Malformed header: show a plain byte counter instead of failing.
                total = None
            with tqdm(total=total, unit='B') as pbar:
                while True:
                    chunk = response.read(64 * 1024)
                    if not chunk:
                        break
                    f.write(chunk)
                    pbar.update(len(chunk))
        os.replace(partial_file, destination_file)
    finally:
        # Only still present when the transfer did not complete.
        if os.path.exists(partial_file):
            os.remove(partial_file)
    return destination_file

def download_files(urls, destination):
    """Fetch every URL in ``urls`` into ``destination`` using download_file.

    Args:
        urls: Iterable of URLs, processed in order.
        destination: Directory handed to download_file for each URL.

    Returns:
        list: Local file paths, in the same order as ``urls``.
    """
    return [download_file(link, destination) for link in urls]

def calculate_sha256sum(filename):
    """Return the SHA-256 digest of a file as a lower-case hex string.

    The file is hashed in fixed-size chunks so that large downloads do not
    have to fit into memory.

    Args:
        filename: Path of the file to hash.

    Returns:
        str: 64-character hexadecimal digest.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    digest = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b''):
            digest.update(chunk)
    return digest.hexdigest()

def assert_sha256sum(file, sha256sum):
    """Verify a file against an expected SHA-256 digest, exiting on mismatch.

    Args:
        file: Path of the file to check.
        sha256sum: Expected digest as a hex string; upper/lower case and
            surrounding whitespace are ignored.

    Returns:
        None when the digest matches.

    Raises:
        SystemExit: With status 1 when the digest differs.
    """
    val = calculate_sha256sum(file)
    # hexdigest() is always lower-case; normalise the expected value so that
    # digests published in upper case still compare equal.
    if val == sha256sum.strip().lower():
        print('Verified checksum ' + str(file))
        return
    print('The downloaded file ' + str(file) + ' could not be verified.')
    print('Actual   SHA256SUM: ' + val)
    print('Expected SHA256SUM: ' + sha256sum)
    print('Please remove the file and rerun the script.')
    sys.exit(1)

def extract_zipfile_entry(archive, zipinfo, output_filename):
    """Extract one archive member to ``output_filename`` without overwriting.

    Directory members are created as directories. For file members, missing
    parent directories are created first. If the target already exists its
    CRC-32 is compared with the one recorded in the archive: identical content
    is skipped, different content aborts the script.

    Args:
        archive: Open zipfile.ZipFile the member belongs to.
        zipinfo: zipfile.ZipInfo describing the member.
        output_filename: Path the member is written to.

    Raises:
        SystemExit: With status 1 when an existing file has different content.
    """
    if zipinfo.is_dir():
        os.makedirs(output_filename, exist_ok=True)
        return

    if Path(output_filename).exists():
        with open(output_filename, 'rb') as existing_file:
            existing_file_crc32 = zlib.crc32(existing_file.read())
        if existing_file_crc32 == zipinfo.CRC:
            print('File with equal content ' + output_filename + ' already exists, skipping...')
            return
        print('\nExisting file with different content found:\n\t' + output_filename + """
Likely another DOS version is installed already. Remove the
existing installation with rmfdusr and rerun this script to proceed.""")
        sys.exit(1)

    parent_directory = os.path.dirname(output_filename)
    if parent_directory:
        os.makedirs(parent_directory, exist_ok=True)
    with open(output_filename, 'wb') as output_file:
        # Read via the ZipInfo object so the exact member is used even if the
        # archive contains several entries with the same name.
        output_file.write(archive.read(zipinfo))
    # Access time: now. Modification time: the timestamp stored in the archive,
    # interpreted as local time (-1 lets mktime decide about DST).
    os.utime(output_filename, (time.time(), time.mktime(zipinfo.date_time + (0, 0, -1))))

def process_disk_image_archive(archive, destination):
    """Unpack a downloaded zip archive into ``destination``.

    Two archive layouts are handled, chosen by the name of the first member:

    * The first member ends in ``.ima``/``.img`` (any case): the archive is
      treated as a set of disk images. It is extracted to TMP_DIR and the
      contents of every member are copied out with ``mcopy``.
    * Otherwise every member is extracted into ``destination`` under its
      lower-cased name via extract_zipfile_entry.

    Args:
        archive: Path of the zip file.
        destination: Directory receiving the extracted files.

    Raises:
        SystemExit: With status 1 when a member name would place a file
            outside of ``destination``.
    """
    with zipfile.ZipFile(archive) as zip_archive:
        names = zip_archive.namelist()
        if not names:
            print('Archive ' + str(archive) + ' contains no entries, nothing to extract.')
            return

        # zip file with disk images
        if re.search(r'(?i)\.im[ag]$', names[0]):
            zip_archive.extractall(TMP_DIR)
            mcopy_env = {"PATH": os.environ['PATH'], "MTOOLS_LOWER_CASE": '1'}
            for image_filename in names:
                image_path = os.path.join(TMP_DIR, image_filename)
                result = subprocess.run(['mcopy', '-sn', '-i', image_path, '::*', destination], env=mcopy_env)
                if result.returncode != 0:
                    # Keep going as before, but do not hide the failure.
                    print('Warning: mcopy exited with status ' + str(result.returncode) + ' for ' + image_filename)
            return

        destination_root = os.path.realpath(destination)
        for entry in zip_archive.infolist():
            print("extracting " + entry.filename)
            filename = entry.filename.lower()
            output_filename = os.path.join(destination, filename)
            # Member names come from a downloaded archive: reject absolute
            # paths and '..' components that would escape the destination.
            resolved = os.path.realpath(output_filename)
            if os.path.commonpath([destination_root, resolved]) != destination_root:
                print('Refusing to extract ' + entry.filename + ': it would be written outside of ' + str(destination))
                sys.exit(1)
            extract_zipfile_entry(zip_archive, entry, output_filename)

def download_and_process(imgurls, name, destination, sha256sums=None):
    """Download a set of files into the cache, verify them and unpack them.

    Files are cached under ``CACHE_DIR/<name>`` (directly under CACHE_DIR when
    no name is given). After optional checksum verification, ``.zip`` files
    are unpacked with process_disk_image_archive and all other files are
    copied into ``destination`` unchanged.

    Args:
        imgurls: URLs to download, in order.
        name: Sub-directory of the cache to use; may be None or empty.
        destination: Directory that receives the unpacked/copied files.
        sha256sums: Optional expected SHA-256 digests, matched to ``imgurls``
            by position. None or an empty list disables verification.

    Raises:
        SystemExit: With status 1 when fewer checksums than URLs are given or
            when a checksum does not match.
    """
    sha256sums = [] if sha256sums is None else sha256sums
    # Checked before downloading so a bad invocation fails fast with a message
    # instead of an IndexError halfway through verification.
    if 0 < len(sha256sums) < len(imgurls):
        print('Got ' + str(len(sha256sums)) + ' sha256sum(s) for ' + str(len(imgurls)) + ' file(s).')
        print('Please specify one sha256sum per downloaded file.')
        sys.exit(1)

    destination_dir = os.path.join(CACHE_DIR, name) if name else CACHE_DIR
    os.makedirs(destination_dir, exist_ok=True)
    destination_files = download_files(imgurls, destination_dir)

    for destination_file, sha256sum in zip(destination_files, sha256sums):
        assert_sha256sum(destination_file, sha256sum)

    os.mkdir(TMP_DIR)
    try:
        for destination_file in destination_files:
            if destination_file.lower().endswith('.zip'):
                process_disk_image_archive(destination_file, destination)
            else:
                shutil.copy(destination_file, destination)
    finally:
        # Also runs on sys.exit() from the helpers, so no scratch data is left behind.
        shutil.rmtree(TMP_DIR)

def verify_tool(executable, package):
    """Exit with an explanatory message unless ``executable`` can be found.

    Args:
        executable: Program name looked up on PATH (or a path to a program).
        package: Name of the package to suggest installing.

    Returns:
        None when the executable is available.

    Raises:
        SystemExit: With status 1 when the executable cannot be found.
    """
    # shutil.which() returns None when nothing executable matches.
    if shutil.which(executable) is None:
        print('Please install ' + package + ' and make sure that ' + executable + ' is on your PATH.')
        sys.exit(1)

def verify_system():
    """Run verify_tool for the external programs listed below."""
    # (executable, package that provides it)
    required_tools = (
        ('mcopy', 'mtools'),
    )
    for executable, package in required_tools:
        verify_tool(executable, package)

def derive_url_list(imgurl):
    """Expand the URL of one disk of a numbered set into URLs for the whole set.

    Two file name patterns are recognised (case-insensitively):

    * ``Disk N of M`` -> disks 1..M
    * ``diskN``       -> disks 1..N (the URL is assumed to name the last disk)

    Any other file name, including one that merely contains the word "disk"
    without a number, yields a list holding just ``imgurl``.

    Args:
        imgurl: URL whose last path segment is the file name to inspect.

    Returns:
        list: The derived URLs in disk order, or ``[imgurl]``.
    """
    filename = urllib.parse.unquote(imgurl.split("/")[-1])
    base_url = imgurl.rsplit("/", 1)[0]

    count_match = re.search(r'(?i)disk\s?\d+\s?of\s(\d+)', filename)
    if count_match:
        number_pattern = r'(?i)(?P<one>disk\s?)\d+(?P<two>\s?of\s\d+)'
        suffix = r'\g<two>'
    else:
        # if the filenames have just one number, assume the last disk from the set was provided
        count_match = re.search(r'(?i)disk\s?(\d+)', filename)
        number_pattern = r'(?i)(?P<one>disk\s?)\d+'
        suffix = ''
    if not count_match:
        return [imgurl]

    imgcnt = int(count_match.group(1))
    # NOTE(review): quote_plus() turns spaces into '+' and re-encodes any '%'
    # already present in the base URL; kept as is because the servers this is
    # used against are not visible from here - confirm before changing.
    imgurls = [
        urllib.parse.quote_plus(
            base_url + "/" + re.sub(number_pattern, r'\g<one>' + str(i) + suffix, filename),
            r"\./_-:")
        for i in range(1, imgcnt + 1)
    ]
    # A count of 0 would otherwise produce nothing to download at all.
    return imgurls or [imgurl]

# Command line entry point: parse the options, then either download and unpack
# into an empty destination, or check / reuse what is already there.
parser = argparse.ArgumentParser(description="""Script to download multiple URLs or a set of disk images.
When a URL to a disk image is provided, based on known filename patterns, all related images are downloaded and extracted.
Either a name should contain \"Disk ? of x\" or \"diskx\" where 1 is the first disk and x is the last disk.
The destination directory will contain all files from all disk images.""")
parser.add_argument("-c", "--custom",    help="Download specified disk/file set, requires a destination directory", nargs='+', type=str)
parser.add_argument("-d", "--destination", help="Specify/override the destination directory", type=str)
parser.add_argument("-s", "--sha256sum", help="Specify one or more sha256sum(s)", nargs='+', type=str)
parser.add_argument("-n", "--name", help="Specify a name to cache files under", type=str)

args = parser.parse_args()
if args.custom and args.destination:
    urls = args.custom
    destination = args.destination
    sha256sums = args.sha256sum if args.sha256sum else []
    os.makedirs(destination, exist_ok=True)
    if not os.listdir(destination):
        # A single URL may name one disk of a numbered set; several URLs are taken literally.
        imgurls = derive_url_list(urls[0]) if len(urls) == 1 else urls
        # Without --name the files are cached directly under the cache root.
        download_and_process(imgurls, args.name or '', destination, sha256sums)
    elif sha256sums and not list(Path(destination).glob('*.sys')):
        # The directory listing order is unrelated to the order of the given
        # checksums, so each file is accepted when it matches any of them.
        for file in sorted(Path(destination).iterdir()):
            if not file.is_file():
                continue
            actual = calculate_sha256sum(file)
            matching = [s for s in sha256sums if s.strip().lower() == actual]
            # assert_sha256sum does the reporting (and exits on mismatch).
            assert_sha256sum(file, matching[0] if matching else ' or '.join(sha256sums))
    else:
        print('Continuing with existing file set in: ' + destination)
elif args.custom:
    parser.error('--custom requires --destination')
else:
    # Nothing to do without a file set; show how the script is meant to be used.
    parser.print_help()
