import os
import yaml
import string
import hashlib
import argparse
from collections import defaultdict
from urllib.parse import urlparse

from Cryptodome.Cipher import AES
from Cryptodome.Util import Padding
from aiohttp import web, hdrs

from thumbnail import generate_thumbnail

ThumbnailOptions = {
    'trim': False,
    'height': 300,
    'width': 300,
    'quality': 85,
    'type': 'thumbnail'
}


class AppConfig:
    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(AppConfig, cls).__new__(cls)
        return cls._instance

    def __init__(self, config_file):
        if not hasattr(self, 'is_loaded'):
            self.load_config(config_file)
            self.is_loaded = True

    def load_config(self, config_file):
        with open(config_file) as f:
            data = yaml.safe_load(f)

        self.url_hash_len = data.get('url_hash_len', 8)
        self.data_path = data.get('data_path', '/data')
        self.auth = data.get('auth', True)
        self.tokens = set(data.get('tokens', []))
        self.del_crypt_key = data.get('del_crypt_key', 'default_key')
        self.show_ext = data.get('show_ext', True)
        self.max_filesize = data.get('max_filesize', '1024 ** 2 * 100')  # Default 100 MB
        self.base_url = data.get('base_url', 'http://localhost/f/')
        self.frontend = data.get('frontend', False)
        self.thumbnails = data.get('thumbnails', False)
        self.thumbnail_path = data.get('thumbnail_path', '/data/thumbs')
        self.thumbnail_strategy = data.get('thumbnail_strategy', 'both')

        self.validate_config()

    def validate_config(self):
        if self.url_hash_len > 32:
            raise ValueError('url_hash_len cannot be greater than 32')
        if self.url_hash_len % 2 != 0:
            raise ValueError('url_hash_len must be a multiple of 2')

        self.max_filesize = self.evaluate_filesize(self.max_filesize)

        self.del_crypt_key = hashlib.md5(self.del_crypt_key.encode()).digest()[:16]

        if not os.path.isdir(self.data_path):
            os.mkdir(self.data_path)

        if not os.path.isdir(self.thumbnail_path):
            os.mkdir(self.thumbnail_path)

        self.base_url = f"{self.base_url.strip("/ \t\r\n")}"

    @staticmethod
    def evaluate_filesize(size_str):
        valid_chars = set(string.hexdigits + '* ')
        if not set(size_str).issubset(valid_chars):
            raise ValueError('Invalid characters in max_filesize')

        try:
            return eval(size_str)
        except Exception:
            raise ValueError('Invalid format for max_filesize')


def sizeof_fmt(num, suffix='B'):
    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
        if abs(num) < 1024.:
            return f"{num:3.1f}{unit}{suffix}"
        num /= 1024.
    return f"{num:.1f}Yi{suffix}"


async def handle_upload(req):
    if conf.auth:
        auth_header = req.headers.get(hdrs.AUTHORIZATION, None)

        if auth_header is None:
            return web.Response(text='Authentication required', status=401)

        try:
            scheme, token = auth_header.split(' ')
            if scheme.lower() != 'bearer':
                raise ValueError
        except ValueError:
            return web.Response(text='Invalid authentication scheme', status=401)

        if token not in conf.tokens:
            return web.Response(text='Access denied', status=403)

    reader = await req.multipart()
    file = await reader.next()

    filename = os.path.basename(file.filename)

    if not os.path.isdir(f'{conf.data_path}'):
        os.mkdir(f'{conf.data_path}')

    for _ in range(100):
        hb = os.urandom(conf.url_hash_len // 2)
        h = hb.hex()
        if h not in file_db:
            break
    else:
        return web.Response(text='url key-space full', status=500)

    file_db[h] = filename
    local_fname = f'{conf.data_path}/{h}_{filename}'
    ext = os.path.splitext(filename)[1] if conf.show_ext else ''
    os.fdopen(os.open(local_fname, os.O_WRONLY | os.O_CREAT, 0o600)).close()
    try:
        valid_file = await recv_file(file, local_fname)
    except IOError:
        return web.Response(text='internal io error', status=500)
    if valid_file:
        c = AES.new(conf.del_crypt_key, AES.MODE_CBC)
        hb = Padding.pad(hb, AES.block_size)
        del_h = (c.encrypt(hb) + c.iv).hex()
        local_thname = f'{conf.thumbnail_path}/{h}_{filename}.png'
        generate_thumbnail(local_fname, local_thname, ThumbnailOptions)

        return web.Response(text=f'{{"file_link":"{conf.base_url}/{h}{ext}",'
                                 f'"thumb_link":"{conf.base_url}/thumb/{h}{ext}",'
                                 f'"delete_link":"{conf.base_url}/del/{del_h}"}}', status=200)

    os.unlink(local_fname)
    del file_db[h]
    return web.Response(text=f'file is bigger than {sizeof_fmt(conf.max_filesize)}', status=413)


async def recv_file(file, local_fname):
    size = 0
    with open(local_fname, 'wb') as f:
        while True:
            chunk = await file.read_chunk()
            if not chunk:
                return True
            size += len(chunk)
            if size > conf.max_filesize:
                return False
            f.write(chunk)


async def handle_delete(req):
    chashiv = req.match_info.get('hash', 'x')
    if len(chashiv) != 64 or not set(chashiv).issubset(hexdigits_set):
        return web.Response(text='invalid delete link', status=400)
    chashiv = bytes.fromhex(chashiv)

    c = AES.new(conf.del_crypt_key, AES.MODE_CBC, iv=chashiv[AES.block_size:])
    fhash = c.decrypt(chashiv[:AES.block_size])
    try:
        fhash = Padding.unpad(fhash, AES.block_size).hex()
    except ValueError:
        pass
    if fhash not in file_db:
        return web.Response(text='this file doesn\'t exist on the server', status=404)
    os.unlink(f"{conf.data_path}/{fhash}_{file_db[fhash]}")
    del file_db[fhash]
    return web.Response(text='file deleted')


async def handle_download(req):
    fhash = req.match_info.get('hash', '').split('.', 1)[0]
    if fhash not in file_db:
        return web.Response(text='file not found', status=404)

    return web.FileResponse(f"{conf.data_path}/{fhash}_{file_db[fhash]}", headers={
        hdrs.CONTENT_DISPOSITION: f'inline;filename="{file_db[fhash]}"'
    })


async def handle_thumbnail(req):
    fhash = req.match_info.get('hash', '').split('.', 1)[0]
    if fhash not in file_db:
        return web.Response(text='file not found', status=404)

    ## TODO: If thumbnail doesn't exist, generate new thumbnail

    return web.FileResponse(f"{conf.thumbnail_path}/{fhash}_{file_db[fhash]}", headers={
        hdrs.CONTENT_DISPOSITION: f'inline;filename="{file_db[fhash]}"'
    })


def main():
    for file in os.listdir(f"{conf.data_path}"):
        try:
            fhash, fname = file.split('_', 1)
        except ValueError:
            print(f"file \"{file}\" has an invalid file name format, skipping...")
            continue
        file_db[fhash] = fname

    parsed_url = urlparse(conf.base_url)
    base_path = parsed_url.path

    app = web.Application()
    app.router.add_post(base_path + '/post', handle_upload)
    app.router.add_get(base_path + '/del/{hash}', handle_delete)
    app.router.add_get(base_path + '/{hash}', handle_download)
    app.router.add_get(base_path + '/thumb/{hash}', handle_thumbnail)

    web.run_app(app, port=80)


def parse_args():
    parser = argparse.ArgumentParser(description="File serving and uploading server intended for use as a ShareX host.")

    parser.add_argument('-c', '--config', default=None,
                        help='Path to the configuration file.')

    parser.add_argument('config_file', nargs='?', default='config.yaml',
                        help='Path to the configuration file (positional argument).')

    args = parser.parse_args()

    if args.config and args.config_file != 'config.yaml':
        print("Warning: Both positional and optional config arguments provided. Using the -c argument.")
        return args.config
    return args.config or args.config_file


if __name__ == '__main__':
    hexdigits_set = set(string.hexdigits)
    file_db = defaultdict(dict)
    conf_name = parse_args()
    print("Loading config file", conf_name)
    conf = AppConfig(conf_name)
    main()