136 lines
3.6 KiB
Python
136 lines
3.6 KiB
Python
from sqlalchemy import Column, Integer, String
|
|
from sqlalchemy.orm import DeclarativeBase
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
from sqlalchemy import select
|
|
|
|
from dotenv import load_dotenv
|
|
import os
|
|
load_dotenv()


def _require_env(name: str) -> str:
    """Return environment variable *name*, raising if it is not set.

    Fails fast at import time with a readable message instead of letting a
    missing setting surface later as ``int(None)`` (TypeError) or as a SQLite
    database file literally named ``None``.
    """
    value = os.getenv(name)
    if value is None:
        raise RuntimeError(f"Required environment variable {name} is not set")
    return value


# Zero-padding width passed to to_base36 when building file URLs.
PADDING = int(_require_env("FILES_PADDING"))

# SQLite database file path, used to build the engine URL below.
DATABASE_NAME = _require_env("DATABASE_NAME")

engine = create_engine(f"sqlite:///{DATABASE_NAME}")

# NOTE(review): the functions in this file each open their own
# ``with Session(...)`` block and do not use this shared session; it is kept
# only in case other modules import it.
session = Session(bind=engine)
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base class shared by the ORM models in this module."""
|
|
|
|
class File(Base):
    """ORM model mapped to the ``files`` table; one row per stored file."""

    __tablename__ = "files"

    # Integer primary key; get_url_from_id turns it into the base-36 URL.
    id = Column(Integer, index=True, primary_key=True)
    # Text after the last "." of the uploaded filename; NULL when it has none.
    extension = Column(String, nullable=True)
    # Uploaded filename without its extension (as split by add_file).
    name = Column(String)
    # Content type as passed to add_file (presumably a MIME type - confirm).
    content_type = Column(String)
    # Size as passed to add_file (presumably in bytes - confirm).
    size = Column(Integer)
    # Content hash as passed to add_file; the hashing algorithm is not visible here.
    hash = Column(String)

    def __repr__(self):
        # Compact debug summary: id, extension and hash only.
        return f"<File(id={self.id}, extension='{self.extension}', hash='{self.hash}')>"
|
|
|
|
def to_base36(n: int, width: int) -> str:
    """Encode a non-negative integer in lowercase base 36, left-padded with zeros.

    Args:
        n: Value to encode; must be >= 0.
        width: Minimum length of the result. Longer encodings are not truncated.

    Returns:
        The base-36 digits of *n* (``0-9`` then ``a-z``), zero-padded to *width*.

    Raises:
        ValueError: If *n* is negative.
    """
    if n < 0:
        raise ValueError("Only non-negative integers supported")
    if n == 0:
        return "0".rjust(width, "0")

    alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
    digits = []
    remaining = n
    # Peel off the least-significant digit each round, then flip at the end.
    while remaining:
        remaining, digit = divmod(remaining, 36)
        digits.append(alphabet[digit])
    digits.reverse()
    return "".join(digits).rjust(width, "0")
|
|
|
|
def get_url_from_id(id: int, extension):
    """Build the URL string for a stored file.

    The id is base-36 encoded and zero-padded to ``PADDING`` characters; when
    *extension* is truthy it is appended after a ".".
    """
    encoded = to_base36(id, PADDING)
    return f"{encoded}.{extension}" if extension else encoded
|
|
|
|
def from_base36(s: str) -> int:
    """Decode a base-36 string (as produced by to_base36) into an integer.

    Decoding is case-insensitive and ignores leading zeros; an empty string,
    or one made only of zeros, decodes to 0.

    Raises:
        ValueError: If *s* contains a character outside ``0-9`` / ``a-z``.
    """
    alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
    total = 0
    for symbol in s.lower().lstrip("0"):
        digit = alphabet.find(symbol)
        if digit == -1:
            raise ValueError(f"Invalid base36 character: {symbol}")
        total = total * 36 + digit
    return total
|
|
|
|
# Create any tables declared on Base (here: ``files``) that do not exist yet.
Base.metadata.create_all(engine)
|
|
|
|
def get_all_files():
    """Return a summary dict for every stored file.

    Each dict has the keys ``"url"`` (built by get_url_from_id), ``"name"``,
    ``"content_type"`` and ``"size"``. Returns an empty list when the
    ``files`` table has no rows.
    """
    with Session(autoflush=False, bind=engine) as db:
        # Debug prints removed: ``type(files[0])`` raised IndexError on an
        # empty table. The dicts are built while the session is still open.
        return [
            {
                "url": get_url_from_id(f.id, f.extension),
                "name": f.name,
                "content_type": f.content_type,
                "size": f.size,
            }
            for f in db.scalars(select(File))
        ]
|
|
|
|
|
|
|
|
def file_exists(size: int, hash_value: str) -> "str | None":
    """Look up an already-stored file by size and hash.

    Despite the name, this does not return a bool: it returns the URL of the
    first matching row (built by get_url_from_id), or ``None`` when no row
    matches. The URL is never empty, so the result is truthy exactly when a
    match exists.

    Args:
        size: Size value to match against ``File.size``.
        hash_value: Hash string to match against ``File.hash``.
    """
    # Only one row is ever used, so let the database stop after the first.
    statement = (
        select(File)
        .where(File.size == size, File.hash == hash_value)
        .limit(1)
    )
    with Session(bind=engine) as db:
        match = db.scalars(statement).first()
        if match is None:
            return None
        return get_url_from_id(match.id, match.extension)
|
|
|
|
|
|
def add_file(filename: str, content_type, size: int, hash):
    """Insert a new file row and return its URL.

    *filename* is split at its last "." into name and extension; a filename
    without any "." is stored with no extension (NULL).

    Returns:
        The URL built by get_url_from_id for the newly assigned id.
    """
    stem, dot, suffix = filename.rpartition(".")
    if dot:
        name, extension = stem, suffix
    else:
        name, extension = filename, None

    with Session(autoflush=False, bind=engine) as db:
        record = File(
            name=name,
            extension=extension,
            content_type=content_type,
            size=size,
            hash=hash,
        )
        db.add(record)
        db.commit()
        # The id is only known after the INSERT; read it while the session is
        # still open so the committed instance can be refreshed.
        return get_url_from_id(record.id, record.extension)
|
|
|
|
|
|
def remove_file(file_url: str):
    """Delete the file row addressed by *file_url*.

    The id is decoded from the text before the first "." of *file_url*, so any
    extension is ignored.

    Returns:
        ``True`` if a row was deleted, ``None`` if no row has that id.

    Raises:
        ValueError: If the id part is not valid base 36 (raised by from_base36).
    """
    target_id = from_base36(file_url.partition(".")[0])
    with Session(autoflush=False, bind=engine) as db:
        record = db.get(File, target_id)
        if record is None:
            return None
        db.delete(record)
        db.commit()
    return True
|
|
|
|
if __name__ == "__main__":
    # get_all_files() returns plain dicts without id/hash/extension, so the
    # old attribute access raised AttributeError; read the ORM rows directly.
    with Session(bind=engine) as db:
        for f in db.scalars(select(File)):
            full_name = f"{f.name}.{f.extension}" if f.extension else f.name
            print(f"{f.id} {full_name} ({f.hash}) {f.content_type} {f.size}")
|
|
|
|
|