diff --git a/api.py b/api.py new file mode 100644 index 0000000..d002dc4 --- /dev/null +++ b/api.py @@ -0,0 +1,37 @@ +from fastapi import FastAPI, File, UploadFile +from sqlalchemy import exists +import hashlib +import db + +def compute_hash(data: bytes, algorithm="sha256") -> str: + h = hashlib.new(algorithm) + h.update(data) + return h.hexdigest() + +app = FastAPI() + +@app.get("/") +def root(): + return {"message": "hiii from sfs"} + + +@app.post("/") +async def save_file(file: UploadFile = File(...)): + contents = await file.read() + + hash = compute_hash(contents) + + existed_url = db.file_exists(file.size, hash) + + if not existed_url: + file_url = db.add_file(file.filename, file.content_type, file.size, hash) + + with open(f"files/{file.filename}", "wb") as f: + f.write(contents) + return {"status": "saved", "filename": file_url} + else: + return {"status": "file_exists", "filename": existed_url} + +@app.get("/api/healthchecker") +def healthchecker(): + return {"message": "Howdy :3"} diff --git a/db.py b/db.py new file mode 100644 index 0000000..fad09f9 --- /dev/null +++ b/db.py @@ -0,0 +1,88 @@ +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy import select + +engine = create_engine("sqlite:///example.db") +session = Session(bind=engine) + +class Base(DeclarativeBase): pass + +class File(Base): + __tablename__ = "files" + + id = Column(Integer, primary_key=True, index=True) + extension = Column(String, nullable=True) + name = Column(String) + content_type = Column(String) + size = Column(Integer) + hash = Column(String) + + def __repr__(self): + return f"" + +def to_base36(n: int) -> str: + if n < 0: + raise ValueError("Only non-negative integers supported") + + chars = "0123456789abcdefghijklmnopqrstuvwxyz" + if n == 0: + return "0" + + result = [] + while n > 0: + n, rem = divmod(n, 36) + result.append(chars[rem]) + return "".join(reversed(result)) + +Base.metadata.create_all(bind=engine) + +def get_all_files(): + with Session(autoflush=False, bind=engine) as db: + statement = select(File) + return db.scalars(statement).all() + +def file_exists(size: int, hash_value: str) -> bool: + with Session(bind=engine) as db: + statement = select(File).where( + File.size == size, + File.hash == hash_value + ) + + existed_file = db.scalars(statement).first() + + if existed_file is None: + return None + + url = f"{to_base36(existed_file.id)}" + if existed_file.extension: + url += f".{existed_file.extension}" + + return url + + +def add_file(filename: str, content_type, size: int, hash): + with Session(autoflush=False, bind=engine) as db: + new_file = File() + if "." in filename: + new_file.name, new_file.extension = filename.rsplit(".", 1) + else: + new_file.name = filename + new_file.extension = None + new_file.content_type = content_type + new_file.size = size + new_file.hash = hash + db.add(new_file) + db.commit() + url = f"{to_base36(new_file.id)}" + if new_file.extension: + url += f".{new_file.extension}" + + return url + +if __name__ == "__main__": + for i in get_all_files(): + print(f"{i.id} {i.name}.{i.extension} ({i.hash}) {i.content_type} {i.size}") + +