Compare commits

...

4 Commits

View File

@ -8,7 +8,7 @@ from collections import defaultdict
from Cryptodome.Cipher import AES from Cryptodome.Cipher import AES
from Cryptodome.Util import Padding from Cryptodome.Util import Padding
from aiohttp import web from aiohttp import web, hdrs
class AttrDict(dict): class AttrDict(dict):
@ -23,17 +23,17 @@ class AttrDict(dict):
def __getattr__(self, item): def __getattr__(self, item):
return self.setdefault(item, AttrDict()) return self.setdefault(item, AttrDict())
@staticmethod
def cast_to_ad(d): def from_dict_recur(d):
if not isinstance(d, AttrDict): if not isinstance(d, AttrDict):
d = AttrDict(d) d = AttrDict(d)
for k, v in dict(d.items()).items(): for k, v in dict(d.items()).items():
if " " in k: if " " in k:
del d[k] del d[k]
d[k.replace(" ", "_")] = v d[k.replace(" ", "_")] = v
if isinstance(v, dict): if isinstance(v, dict):
d[k] = cast_to_ad(v) d[k] = AttrDict.from_dict_recur(v)
return d return d
def sizeof_fmt(num, suffix='B'): def sizeof_fmt(num, suffix='B'):
@ -47,7 +47,7 @@ def sizeof_fmt(num, suffix='B'):
async def prepare(_, handler): async def prepare(_, handler):
async def prepare_handler(req): async def prepare_handler(req):
if 'acc' not in req.match_info: if 'acc' not in req.match_info:
return web.Response(text='internal server error', status=500) return web.Response(text='bad request', status=400)
return await handler(req, req.match_info["acc"], file_db[req.match_info["acc"]]) return await handler(req, req.match_info["acc"], file_db[req.match_info["acc"]])
return prepare_handler return prepare_handler
@ -126,8 +126,10 @@ async def handle_download(req, acc, acc_db):
fhash = req.match_info.get('hash', '').split('.', 1)[0] fhash = req.match_info.get('hash', '').split('.', 1)[0]
if fhash not in acc_db: if fhash not in acc_db:
return web.Response(text='file not found', status=404) return web.Response(text='file not found', status=404)
return web.FileResponse(f"{conf.data_path}/{acc}/{fhash}_{acc_db[fhash]}", headers={ return web.FileResponse(f"{conf.data_path}/{acc}/{fhash}_{acc_db[fhash]}", headers={
'CONTENT-DISPOSITION': f'inline;filename="{acc_db[fhash]}"' hdrs.CACHE_CONTROL: "no-cache",
hdrs.CONTENT_DISPOSITION: f'inline;filename="{acc_db[fhash]}"'
}) })
@ -138,8 +140,9 @@ def main():
if not os.path.isdir(f'{conf.data_path}/{acc}'): if not os.path.isdir(f'{conf.data_path}/{acc}'):
continue continue
for file in os.listdir(f"{conf.data_path}/{acc}"): for file in os.listdir(f"{conf.data_path}/{acc}"):
fhash, fname = file.split('_', 1) if "_" in file:
file_db[acc][fhash] = fname fhash, fname = file.split('_', 1)
file_db[acc][fhash] = fname
app = web.Application(middlewares=[prepare]) app = web.Application(middlewares=[prepare])
app.router.add_post(conf.prefix + '/post/{acc}', handle_upload) app.router.add_post(conf.prefix + '/post/{acc}', handle_upload)
@ -154,7 +157,7 @@ if __name__ == '__main__':
file_db = defaultdict(dict) file_db = defaultdict(dict)
confname = sys.argv[1] if sys.argv[1:] and os.path.isfile(sys.argv[1]) else 'config.yaml' confname = sys.argv[1] if sys.argv[1:] and os.path.isfile(sys.argv[1]) else 'config.yaml'
with open(confname) as cf: with open(confname) as cf:
conf = cast_to_ad(yaml.load(cf)) conf = AttrDict.from_dict_recur(yaml.load(cf))
if conf.url_hash_len > 31: if conf.url_hash_len > 31:
raise ValueError('url_hash_len can\'t be bigger than 31') raise ValueError('url_hash_len can\'t be bigger than 31')
if not set(conf.max_filesize.replace(' ', ''))\ if not set(conf.max_filesize.replace(' ', ''))\