diff --git a/routes/auth.py b/routes/auth.py index 145871e..7ef529b 100644 --- a/routes/auth.py +++ b/routes/auth.py @@ -1,9 +1,24 @@ -from fastapi import APIRouter -from service.auth_service import Credential +from fastapi import APIRouter, Response +from util.auth_lib import hash, gen_token +from service.auth_service import Credential, AuthService router = APIRouter() @router.post("/auth/login") -def login(auth: Credential): +def login(auth: Credential, resp: Response): + service = AuthService() + data = service.read(auth.username) - return {"ok": 1, "token": "Basic {}"} + hashed = hash(auth.password, data.salt) + if not data.username == auth.username and not data.password == hashed: + resp.status_code = 401 + return { + "ok": 0, + "errno": "Unauthorized" + } + + token = gen_token(auth.username, hashed) + return { + "ok": 1, + "token": "Basic {}".format(token) + } diff --git a/routes/balance.py b/routes/balance.py index cd16b49..1ba6d83 100644 --- a/routes/balance.py +++ b/routes/balance.py @@ -1,14 +1,26 @@ from datetime import datetime -from fastapi import APIRouter, Response +from fastapi import APIRouter, Response, Request +from service.auth_service import AuthService from service.balance_service import Balance, BalanceService, UpdateForm router = APIRouter() @router.post("/balance", status_code=201) -def insert(balance: Balance, resp: Response): +def insert(balance: Balance, req: Request, resp: Response): started = datetime.now().microsecond / 1000 + auth = AuthService() + + if not auth.check_auth(req): + resp.status_code = 403 + return { + "ok": 0, + "errno": "permission denied" + } + + info = auth.get_data(req) + service = BalanceService() - ok = service.create(balance=balance) + ok = service.create(info["username"], balance=balance) if not ok == 1: resp.status_code = 500 return { @@ -23,9 +35,44 @@ def insert(balance: Balance, resp: Response): "respond_time": "{}ms".format(round((datetime.now().microsecond / 1000) - started)) } -@router.get("/balance/{id}") -def query(id, resp: Response): +@router.get("/balance") +def query(req: Request, resp: Response): started = datetime.now().microsecond / 1000 + auth = AuthService() + if not auth.check_auth(req): + resp.status_code = 403 + return { + "ok": 0, + "errno": "permission denied" + } + + service = BalanceService() + data = service.query() + if data == None: + resp.status_code = 204 + return { + "ok": 0, + "errno": "no content" + } + + return { + "ok": 1, + "data": data, + "respond_time": "{}ms".format(round((datetime.now().microsecond / 1000) - started)) + } + +@router.get("/balance/{id}") +def find(id, req: Request, resp: Response): + started = datetime.now().microsecond / 1000 + auth = AuthService() + + if not auth.check_auth(req): + resp.status_code = 403 + return { + "ok": 0, + "errno": "permission denied" + } + service = BalanceService() data = service.read(int(id)) @@ -41,7 +88,18 @@ def query(id, resp: Response): } @router.patch("/balance/{action}/{id}") -def update(action, id, balance: UpdateForm, resp: Response): +def update(action, id, balance: UpdateForm, req: Request, resp: Response): + started = datetime.now().microsecond / 1000 + auth = AuthService() + + print(auth.check_auth(req)) + if not auth.check_auth(req): + resp.status_code = 403 + return { + "ok": 0, + "errno": "permission denied" + } + service = BalanceService() if action != "name" and action != "date" and action != "price" and action != "buy" and action != "memo": print(action) @@ -91,12 +149,22 @@ def update(action, id, balance: UpdateForm, resp: Response): return { "ok": 1, "id": int(id), - "action": action + "action": action, + "respond_time": "{}ms".format(round((datetime.now().microsecond / 1000) - started)) } @router.delete("/balance/{id}") -def delete(id, resp: Response): +def delete(id, req: Request, resp: Response): started = datetime.now().microsecond / 1000 + auth = AuthService() + + if not auth.check_auth(req): + resp.status_code = 403 + return { + "ok": 0, + "errno": "permission denied" + } + service = BalanceService() ok = service.delete(int(id)) if not ok == 1: diff --git a/service/auth_service.py b/service/auth_service.py index e31a09b..9b84a14 100644 --- a/service/auth_service.py +++ b/service/auth_service.py @@ -44,7 +44,7 @@ class AuthService: def read(self, username: str): cur = self._conn.cursor() - cur.execute("select * from account where username = %s;", (username)) + cur.execute("select * from account where username = %s;", (username, )) data = cur.fetchone() if data == None: return None @@ -58,16 +58,32 @@ class AuthService: password = data[2], salt = data[3] ) - - def check_auth(self, req: Request) -> bool: + + def get_data(self, req: Request): raw = req.headers.get("Authorization") + if raw == None: + return None + raw_token = raw.removeprefix("Basic ").encode("ascii") token = base64.b64decode(raw_token) data = token.decode("utf-8").split(":") + + return { + "username": data[0], + "password": data[1] + } - acc = self.read(data[0]) - if acc.username == data[0] and acc.password == data[1]: + def check_auth(self, req: Request) -> bool: + data = self.get_data(req) + if data == None: + return False + + acc = self.read(data["username"]) + if acc == None: + return False + + if acc.username == data["username"] and acc.password == data["password"]: return True return False diff --git a/service/balance_service.py b/service/balance_service.py index c67b82f..9ca9907 100644 --- a/service/balance_service.py +++ b/service/balance_service.py @@ -20,13 +20,13 @@ class BalanceService: def __init__(self): self._conn = psycopg2.connect(conn_param) - def create(self, balance: Balance): + def create(self, username: str, balance: Balance): ok = True cur = self._conn.cursor() try: cur.execute( - "insert into balset(name, date, price, buy, memo) values (%s, %s, %s, %s, %s);", - (balance.name, balance.date, balance.price, balance.buy, balance.memo) + "insert into balset(name, uid, date, price, buy, memo) values (%s, %s, %s, %s, %s, %s);", + (balance.name, username, balance.date, balance.price, balance.buy, balance.memo) ) self._conn.commit() @@ -39,6 +39,28 @@ class BalanceService: return ok + def query(self): + cur = self._conn.cursor() + cur.execute("select * from balset;") + + raw = cur.fetchall() + data = [] + + if len(raw) == 0: + return None + + for d in raw: + data.append({ + "id": d[0], + "name": d[1], + "date": d[2], + "price": d[3], + "buy": d[4], + "memo": d[5] + }) + + return data + def read(self, id: int): cur = self._conn.cursor() cur.execute("select * from balset where id = %s;", (id))