#!/usr/bin/python3

import argparse
import hashlib
import os
import re
import shutil
import subprocess
import sys
import urllib.parse
import urllib.request
import zipfile
from pathlib import Path
from tqdm import tqdm

def check_sum_pkg_info(destination_file):
    """Verify ``destination_file`` against the checksum in its pkg-info file.

    The pkg-info file is looked up next to the download: the last three
    characters of ``destination_file`` are replaced by ``txt``. The first
    line starting with ``SHA`` supplies the expected digest (everything
    after the first four characters of that line, stripped). The digest is
    then checked with ``assert_sha256sum``, which exits the script on a
    mismatch. If no such line exists, a warning is printed and nothing is
    verified.

    Raises:
        OSError: if the pkg-info file cannot be opened.
    """
    expected = None
    # Close the pkg-info file before hashing the (possibly large) download.
    with open(destination_file[:-3] + 'txt', 'r') as pkg_info_file:
        for line in pkg_info_file:
            if line.startswith('SHA'):
                expected = line[4:].strip()
                break

    if expected is None:
        print('Warning no SHA checksum in pkg-info. Could not verify checksum of ' + destination_file)
        return
    assert_sha256sum(destination_file, expected)

def check_sum(destination_file):
    """Verify ``destination_file`` if a pkg-info file exists next to it.

    The pkg-info path is ``destination_file`` with its last three
    characters replaced by ``txt``. When that file is missing, only a
    warning is printed; otherwise verification is delegated to
    ``check_sum_pkg_info``.
    """
    pkg_info_path = Path(destination_file[:-3] + 'txt')
    if not pkg_info_path.is_file():
        print(f'Warning could not verify checksum of {destination_file}')
        return
    check_sum_pkg_info(destination_file)

def download_file(source, destination_directory):
    """Download ``source`` into ``destination_directory`` unless already present.

    The local file name is the percent-decoded last path segment of
    ``source``. If that file already exists nothing is downloaded; for
    anything other than a ``.txt`` file, ``check_sum`` is called on the
    existing file instead.

    Otherwise the URL is fetched and written in 16 KiB chunks. A tqdm
    progress bar is shown when the response reports a numeric
    Content-Length. If the transfer fails or is interrupted, the partially
    written file is removed, so that a rerun does not mistake it for a
    finished download.

    Returns:
        The path of the local file.

    Raises:
        OSError: on network or file-system errors (urllib's ``URLError`` is
            an ``OSError`` subclass).
    """
    destination_file = os.path.join(destination_directory, urllib.parse.unquote(source.split("/")[-1]))
    if Path(destination_file).is_file():
        if not destination_file.endswith('.txt'):
            check_sum(destination_file)
        return destination_file

    print('Downloading ' + source + '...')
    # Raw string keeps the same safe characters (backslash . / _ - :)
    # without relying on an invalid "\." escape sequence.
    # NOTE(review): quote_plus encodes spaces as '+' and re-encodes '%', which
    # is form/query-string quoting rather than path quoting -- confirm this is
    # what the target servers expect before changing it.
    url = urllib.parse.quote_plus(source, r"\./_-:")
    try:
        with urllib.request.urlopen(url) as response, open(destination_file, 'wb') as f:
            if hasattr(response, 'headers'):
                length = response.headers['Content-length']
            elif hasattr(response, 'getheader'):
                length = response.getheader('content-length')
            else:
                length = None

            if isinstance(length, str) and length.strip().isdigit():
                with tqdm(total=int(length), unit='B', maxinterval=0.5) as pbar:
                    for chunk in iter(lambda: response.read(16384), b''):
                        f.write(chunk)
                        pbar.update(len(chunk))
            else:
                # Unknown size: stream without a progress bar instead of
                # buffering the whole body in memory.
                shutil.copyfileobj(response, f, 16384)
    except BaseException:
        # Includes KeyboardInterrupt: never leave a truncated download behind.
        if os.path.isfile(destination_file):
            os.remove(destination_file)
        raise
    return destination_file

def calculate_sha256sum(filename):
    """Return the SHA-256 hex digest of the file at ``filename``.

    The file is hashed in 1 MiB chunks so that memory use stays bounded
    regardless of the file size.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    digest = hashlib.sha256()
    with open(filename, "rb") as f:
        # iter(callable, sentinel) stops at the empty read that marks EOF.
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            digest.update(chunk)
    return digest.hexdigest()

def assert_sha256sum(file, sha256sum):
    """Exit the script unless ``file`` has the SHA-256 digest ``sha256sum``.

    The expected digest is compared case-insensitively and with surrounding
    whitespace ignored, because ``calculate_sha256sum`` yields a lower-case
    hex digest while user-supplied digests may be upper-case.

    On success a confirmation line is printed. On mismatch both digests are
    printed and the script terminates with exit status 1.
    """
    expected = sha256sum.strip().lower()
    val = calculate_sha256sum(file)
    if val == expected:
        print('Verified checksum ' + str(file))
        return

    print('The downloaded file ' + str(file) + ' could not be verified.')
    print('Actual   SHA256SUM: ' + val)
    print('Expected SHA256SUM: ' + expected)
    print('Please remove the file and rerun the script.')
    sys.exit(1)

def download_files(imgurls, destination):
    """Fetch every URL in ``imgurls`` into ``destination``.

    Returns the local paths reported by ``download_file``, in the same
    order as ``imgurls``.
    """
    return [download_file(single_url, destination) for single_url in imgurls]

def download_and_process(imgurls, destination, sha256sums=None):
    """Download all ``imgurls`` and verify them against ``sha256sums``.

    Args:
        imgurls: URLs to download, in order.
        destination: directory the files are stored in.
        sha256sums: optional expected SHA-256 digests; the i-th digest
            belongs to the i-th URL. ``None`` or an empty sequence skips
            verification.

    Each download that has a matching digest is checked with
    ``assert_sha256sum``, which exits the script on a mismatch. Downloads
    without a digest (fewer digests than URLs) are reported with a warning
    rather than raising an IndexError.
    """
    destination_files = download_files(imgurls, destination)
    if not sha256sums:
        return

    for destination_file, sha256sum in zip(destination_files, sha256sums):
        assert_sha256sum(destination_file, sha256sum)

    # Fewer checksums than files: say so instead of crashing mid-way.
    for destination_file in destination_files[len(sha256sums):]:
        print('Warning could not verify checksum of ' + str(destination_file))

def verify_noexistinginstall(destination):
    """Exit with status 1 if the ``destination`` directory is not empty.

    Returns ``None`` when the directory has no entries.
    """
    if not os.listdir(destination):
        return
    print(f'There is a already an existing set of DOS installation files in {destination}.')
    sys.exit(1)

def derive_url_list(imgurl):
    """Expand the URL of one disk image into the URLs of the whole set.

    The percent-decoded file name (last path segment) is matched
    case-insensitively against two patterns:

    * ``disk N of M`` -- URLs for disks 1..M are returned.
    * ``disk N`` -- the given disk is assumed to be the last one of the
      set, so URLs for disks 1..N are returned.

    Generated URLs carry the decoded file name. If neither pattern matches
    (including names that merely contain "disk" or "of" without a disk
    number), the original URL is returned unchanged as the only element.
    """
    filename = urllib.parse.unquote(imgurl.split("/")[-1])
    base_url = imgurl.rsplit("/", 1)[0] + "/"

    numbered_of_total = re.search(r'(?i)disk\s?\d+\s?of\s(\d+)', filename)
    if numbered_of_total:
        imgcnt = int(numbered_of_total.group(1))
        pattern = r'(?i)(?P<one>disk\s?)\d+(?P<two>\s?of\s\d+)'
        return [
            base_url + re.sub(pattern, r'\g<one>' + str(i) + r'\g<two>', filename)
            for i in range(1, imgcnt + 1)
        ]

    # Only one number in the name: assume the last disk of the set was given.
    numbered = re.search(r'(?i)disk\s?(\d+)', filename)
    if numbered:
        imgcnt = int(numbered.group(1))
        pattern = r'(?i)(?P<one>disk\s?)\d+'
        return [
            base_url + re.sub(pattern, r'\g<one>' + str(i), filename)
            for i in range(1, imgcnt + 1)
        ]

    return [imgurl]

# Command-line interface: one positional URL plus an optional destination
# directory and an optional list of SHA-256 digests.
# NOTE(review): because -s/--sha256sum accepts one or more values, a URL
# placed directly after the digests is likely to be consumed as a digest;
# put the URL first or separate it with "--" -- confirm with real invocations.
parser = argparse.ArgumentParser(
    description="""Script to download a set of disk images.
When a custom URL to a disk image is provided, based on known filename patterns, all related images are downloaded and extracted.
Either a name should contain \"Disk ? of x\" or \"diskx\" where 1 is the first disk and x is the last disk.
The destination directory will contain all files from all disk images.""",
)
parser.add_argument(
    "-d", "--destination",
    type=str,
    help="Specify/override the destination directory",
)
parser.add_argument(
    "-s", "--sha256sum",
    type=str,
    nargs='+',
    help="Specify one or more sha256sum(s)",
)
parser.add_argument(
    "url",
    type=str,
    help="Specify an URL to download",
)

args = parser.parse_args()

# Script body. Nothing is done unless a destination directory was given.
#   * empty destination      -> derive the URL set, download and verify it
#   * files but no *.sys and
#     digests were supplied  -> verify the existing entries by position
#   * otherwise              -> keep the existing file set as it is
# NOTE(review): Path.iterdir() yields entries in arbitrary order and may
# return more entries than there are digests, so the positional pairing
# below looks fragile -- confirm the intended file/digest mapping.
if args.destination:
    url = args.url
    destination = args.destination
    os.makedirs(destination, exist_ok=True)
    checksums = args.sha256sum
    if not os.listdir(destination):
        download_and_process(derive_url_list(url), destination, checksums or [])
    elif checksums and not any(Path(destination).glob('*.sys')):
        for index, existing_file in enumerate(Path(destination).iterdir()):
            assert_sha256sum(existing_file, checksums[index])
    else:
        print(f'Continuing with existing file set in: {destination}')
