diff --git a/README.md b/README.md index d52e1f0..1b0ba8c 100644 --- a/README.md +++ b/README.md @@ -39,12 +39,7 @@ DB_USERNAME=user DB_PASSWORD=sample1234! ``` -5. generate.py를 실행하여 테이블 및 계정을 생성 해줍니다. -```bash -python generate.py -``` - -6. fastapi 명령어를 이용하여 서비스를 실행 해줍니다. +5. fastapi 명령어를 이용하여 서비스를 실행 해줍니다. ```bash fastapi run app.py ``` @@ -52,3 +47,6 @@ fastapi run app.py ```bash fastapi run app.py --port 3000 ``` + +## License +본 프로젝트는 (MIT License)[https://github.com/devproje/balance-application/blob/master/LICENSE]를 따릅니다. diff --git a/app.py b/app.py index 1bd4397..246eed5 100644 --- a/app.py +++ b/app.py @@ -1,9 +1,29 @@ +import psycopg2 +from generate import on_load from fastapi import FastAPI, Response from routes.auth import router as auth +from contextlib import asynccontextmanager +from util.config import conn_param, db_url from routes.balance import router as balance from fastapi.middleware.cors import CORSMiddleware -app = FastAPI() +@asynccontextmanager +async def lifespan(app: FastAPI): + conn = psycopg2.connect(conn_param) + cur = conn.cursor() + + try: + print("loading database for: %s" % db_url()) + on_load(conn, cur) + except: + print("[warn] error occurred while creating table. aborted") + finally: + cur.close() + conn.close() + + yield + +app = FastAPI(lifespan=lifespan) app.add_middleware( CORSMiddleware, @@ -14,7 +34,7 @@ app.add_middleware( ) @app.get("/") -def index(resp: Response): +async def index(resp: Response): resp.headers.setdefault("Content-Type", "text") return "Hello, World!" diff --git a/generate.py b/generate.py index b35ccfd..2441220 100644 --- a/generate.py +++ b/generate.py @@ -1,94 +1,61 @@ -import psycopg2 import random, string from getpass import getpass from util.auth_lib import hash -from util.config import conn_param from service.auth_service import AuthData, AuthService def gen_salt(length = 20): letters = string.ascii_lowercase + string.digits + string.punctuation return "".join(random.choice(letters) for i in range(length)) -def _gen_token(): - deps = string.ascii_lowercase + string.ascii_uppercase + string.digits + string.punctuation - token = "".join(random.choice(deps) for i in range(20)) +def _new_account(): + name = input("input your display name: ") + username = input("input your username: ") + password = getpass("input your password: ") + passchk = getpass("type password one more time: ") + salt = gen_salt() - sec = open("./secret_token.txt", "w") - sec.write(token) - - sec.close() - -def __main__(): - conn = psycopg2.connect(conn_param) - cur = conn.cursor() - - try: - f = open("./load.txt", "r") - _gen_token() - if f.read().split("=")[1] == "false": - raise ValueError("value not true") - - print("server already initialized") - f.close() - except: - cur.execute( - """ - create table if not exists account( - name varchar(25), - username varchar(25) not null, - password varchar(100) not null, - salt varchar(50), - primary key(username) - ); - """ - ) - - cur.execute( - """ - create table if not exists balset( - id serial primary key, - uid varchar(25) not null, - name varchar(50), - date bigint, - price bigint, - buy boolean, - memo varchar(300), - constraint FK_Account_ID - foreign key (uid) - references account(username) - on delete CASCADE - ); - """ - ) - - conn.commit() - - cur.close() - conn.close() - - name = input("input your display name: ") - username = input("input your username: ") - password = getpass("input your password: ") - passchk = getpass("type password one more time: ") - salt = gen_salt() - - if password != passchk: - return + if password != passchk: + return - hashed_password = hash(password, salt) - packed = AuthData( - name=name, - username=username, - password=hashed_password, - salt=salt - ) + hashed_password = hash(password, salt) + packed = AuthData( + name=name, + username=username, + password=hashed_password, + salt=salt + ) - service = AuthService() - service.create(data=packed) + service = AuthService() + service.create(data=packed) - f = open("load.txt", "w") - f.write("init=true") - - f.close() - -__main__() +def on_load(conn, cur): + cur.execute( + """ + create table account( + name varchar(25), + username varchar(25) not null, + password varchar(100) not null, + salt varchar(50), + primary key(username) + ); + """ + ) + cur.execute( + """ + create table balset( + id serial primary key, + uid varchar(25) not null, + name varchar(50), + date bigint, + price bigint, + buy boolean, + memo varchar(300), + constraint FK_Account_ID + foreign key (uid) + references account(username) + on delete CASCADE + ); + """ + ) + conn.commit() + _new_account() diff --git a/routes/balance.py b/routes/balance.py index 1ba6d83..3de79ab 100644 --- a/routes/balance.py +++ b/routes/balance.py @@ -87,12 +87,10 @@ def find(id, req: Request, resp: Response): "respond_time": "{}ms".format(round((datetime.now().microsecond / 1000) - started)) } -@router.patch("/balance/{action}/{id}") -def update(action, id, balance: UpdateForm, req: Request, resp: Response): +@router.put("/balance/{id}") +def update(id, balance: UpdateForm, req: Request, resp: Response): started = datetime.now().microsecond / 1000 auth = AuthService() - - print(auth.check_auth(req)) if not auth.check_auth(req): resp.status_code = 403 return { @@ -101,43 +99,7 @@ def update(action, id, balance: UpdateForm, req: Request, resp: Response): } service = BalanceService() - if action != "name" and action != "date" and action != "price" and action != "buy" and action != "memo": - print(action) - print(id) - resp.status_code = 400 - return {"ok": 0, "errno": "action must be to name, date, price or memo"} - - if action == "name" and balance.name == "": - resp.status_code = 400 - return {"ok": 0, "action": action, "errno": "name value cannot be empty"} - - if action == "date" and balance.date <= 0: - resp.status_code = 400 - return {"ok": 0, "action": action, "errno": "date value cannot be 0 or minus"} - - if action == "price" and balance.price <= 0: - resp.status_code = 400 - return {"ok": 0, "action": action, "errno": "price value cannot be 0 or minus"} - - if action == "memo" and len(balance.memo) > 300: - resp.status_code = 400 - return { - "ok": 0, - "action": action, - "errno": "memo value size is too long: (maximum size: 300 bytes, your size: {} bytes)".format(len(balance.memo)) - } - - ok = service.update( - int(id), - action, - { - "name": balance.name, - "date": balance.date, - "price": balance.price, - "buy": balance.buy, - "memo": balance.memo - } - ) + ok = service.update(int(id), balance) if not ok == 1: resp.status_code = 500 @@ -149,7 +111,6 @@ def update(action, id, balance: UpdateForm, req: Request, resp: Response): return { "ok": 1, "id": int(id), - "action": action, "respond_time": "{}ms".format(round((datetime.now().microsecond / 1000) - started)) } diff --git a/service/balance_service.py b/service/balance_service.py index 8a64be1..c45d6f3 100644 --- a/service/balance_service.py +++ b/service/balance_service.py @@ -82,11 +82,18 @@ class BalanceService: "memo": data[5] } - def update(self, id: int, act: str, balance: UpdateForm): + def update(self, id: int, balance: UpdateForm): ok = True cur = self._conn.cursor() try: - cur.execute(f"update balset set {act} = %s where id = %s;", (balance[act], id)) + cur.execute("update balset set name = %s, date = %s, price = %s, buy = %s, memo = %s where id = %s;", ( + balance.name, + balance.date, + balance.price, + balance.buy, + balance.memo, + id + )) self._conn.commit() except: self._conn.rollback() @@ -101,7 +108,7 @@ class BalanceService: ok = True cur = self._conn.cursor() try: - cur.execute("delete from balset where id = %s;", (id)) + cur.execute("delete from balset where id = %s;", (str(id))) self._conn.commit() except: self._conn.rollback() diff --git a/util/config.py b/util/config.py index 19f0e5f..e88ba8d 100644 --- a/util/config.py +++ b/util/config.py @@ -1,7 +1,7 @@ import os from dotenv import load_dotenv -load_dotenv() +load_dotenv(verbose=True, override=True) def _load_secret(): try: @@ -11,6 +11,9 @@ def _load_secret(): return tok +def db_url(): + return os.getenv("DB_URL") + conn_param = "host=%s port=%s dbname=%s user=%s password=%s" % ( os.getenv("DB_URL"), os.getenv("DB_PORT"),