mirror of
https://github.com/dancojocaru2000/foxbank.git
synced 2025-02-22 23:39:36 +02:00
Fixed error handling
This commit is contained in:
parent
9ded9cc604
commit
31511f6004
6 changed files with 34 additions and 95 deletions
|
@ -4,8 +4,14 @@ from flask_smorest import Api
|
||||||
from .accounts import bp as acc_bp
|
from .accounts import bp as acc_bp
|
||||||
from .login import bp as login_bp
|
from .login import bp as login_bp
|
||||||
|
|
||||||
|
class ApiWithErr(Api):
|
||||||
|
def handle_http_exception(self, error):
|
||||||
|
if error.data and error.data['response']:
|
||||||
|
return error.data['response']
|
||||||
|
return super().handle_http_exception(error)
|
||||||
|
|
||||||
def init_apis(app: Flask):
|
def init_apis(app: Flask):
|
||||||
api = Api(app, spec_kwargs={
|
api = ApiWithErr(app, spec_kwargs={
|
||||||
'title': 'FoxBank',
|
'title': 'FoxBank',
|
||||||
'version': '1',
|
'version': '1',
|
||||||
'openapi_version': '3.0.0',
|
'openapi_version': '3.0.0',
|
||||||
|
|
|
@ -40,9 +40,9 @@ def get_valid_account_types():
|
||||||
def get_account_id(account_id: int):
|
def get_account_id(account_id: int):
|
||||||
account = db_utils.get_account(account_id=account_id)
|
account = db_utils.get_account(account_id=account_id)
|
||||||
if account is None:
|
if account is None:
|
||||||
return returns.NOT_FOUND
|
return returns.abort(returns.NOT_FOUND)
|
||||||
if decorators.user_id != db_utils.whose_account(account):
|
if decorators.user_id != db_utils.whose_account(account):
|
||||||
return returns.UNAUTHORIZED
|
return returns.abort(returns.UNAUTHORIZED)
|
||||||
account = account.to_json()
|
account = account.to_json()
|
||||||
return returns.success(account=account)
|
return returns.success(account=account)
|
||||||
|
|
||||||
|
@ -54,9 +54,9 @@ def get_account_id(account_id: int):
|
||||||
def get_account_iban(iban: str):
|
def get_account_iban(iban: str):
|
||||||
account = db_utils.get_account(iban=iban)
|
account = db_utils.get_account(iban=iban)
|
||||||
if account is None:
|
if account is None:
|
||||||
return returns.NOT_FOUND
|
return returns.abort(returns.NOT_FOUND)
|
||||||
if decorators.user_id != db_utils.whose_account(account):
|
if decorators.user_id != db_utils.whose_account(account):
|
||||||
return returns.UNAUTHORIZED
|
return returns.abort(returns.UNAUTHORIZED)
|
||||||
account = account.to_json()
|
account = account.to_json()
|
||||||
return returns.success(account=account)
|
return returns.success(account=account)
|
||||||
|
|
||||||
|
@ -80,9 +80,9 @@ class AccountsList(MethodView):
|
||||||
def post(self, currency: str, account_type: str, custom_name: str):
|
def post(self, currency: str, account_type: str, custom_name: str):
|
||||||
"""Create account"""
|
"""Create account"""
|
||||||
if currency not in VALID_CURRENCIES:
|
if currency not in VALID_CURRENCIES:
|
||||||
abort(HTTPStatus.UNPROCESSABLE_ENTITY)
|
return returns.abort(returns.invalid_argument('currency'))
|
||||||
if account_type not in ACCOUNT_TYPES:
|
if account_type not in ACCOUNT_TYPES:
|
||||||
abort(HTTPStatus.UNPROCESSABLE_ENTITY)
|
return returns.abort(returns.invalid_argument('account_type'))
|
||||||
|
|
||||||
account = Account(-1, '', currency, account_type, custom_name or '')
|
account = Account(-1, '', currency, account_type, custom_name or '')
|
||||||
db_utils.insert_account(decorators.user_id, account)
|
db_utils.insert_account(decorators.user_id, account)
|
||||||
|
|
|
@ -29,11 +29,11 @@ class Login(MethodView):
|
||||||
"""Login via username and TOTP code"""
|
"""Login via username and TOTP code"""
|
||||||
user: User | None = get_user(username=username)
|
user: User | None = get_user(username=username)
|
||||||
if user is None:
|
if user is None:
|
||||||
return returns.INVALID_DETAILS
|
return returns.abort(returns.INVALID_DETAILS)
|
||||||
|
|
||||||
otp = TOTP(user.otp)
|
otp = TOTP(user.otp)
|
||||||
if not otp.verify(code, valid_window=1):
|
if not otp.verify(code, valid_window=1):
|
||||||
return returns.INVALID_DETAILS
|
return returns.abort(returns.INVALID_DETAILS)
|
||||||
|
|
||||||
token = ram_db.login_user(user.id)
|
token = ram_db.login_user(user.id)
|
||||||
return returns.success(token=token)
|
return returns.success(token=token)
|
||||||
|
|
|
@ -49,13 +49,13 @@ class Module(ModuleType):
|
||||||
def wrapper(*args, **kargs):
|
def wrapper(*args, **kargs):
|
||||||
token = request.headers.get('Authorization', None)
|
token = request.headers.get('Authorization', None)
|
||||||
if token is None:
|
if token is None:
|
||||||
return returns.NO_AUTHORIZATION
|
return returns.abort(returns.NO_AUTHORIZATION)
|
||||||
if not token.startswith('Bearer '):
|
if not token.startswith('Bearer '):
|
||||||
return returns.INVALID_AUTHORIZATION
|
return returns.abort(returns.INVALID_AUTHORIZATION)
|
||||||
token = token[7:]
|
token = token[7:]
|
||||||
user_id = ram_db.get_user(token)
|
user_id = ram_db.get_user(token)
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
return returns.INVALID_AUTHORIZATION
|
return returns.abort(returns.INVALID_AUTHORIZATION)
|
||||||
|
|
||||||
global _token
|
global _token
|
||||||
_token = token
|
_token = token
|
||||||
|
@ -71,40 +71,4 @@ class Module(ModuleType):
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
# def ensure_logged_in(token=False, user_id=False):
|
|
||||||
# """
|
|
||||||
# Ensure the user is logged in by providing an Authorization: Bearer token
|
|
||||||
# header.
|
|
||||||
#
|
|
||||||
# @param token whether the token should be supplied after validation
|
|
||||||
# @param user_id whether the user_id should be supplied after validation
|
|
||||||
# @return decorator which supplies the requested parameters
|
|
||||||
# """
|
|
||||||
# def decorator(fn):
|
|
||||||
# pass_token = token
|
|
||||||
# pass_user_id = user_id
|
|
||||||
#
|
|
||||||
# @wraps(fn)
|
|
||||||
# def wrapper(*args, **kargs):
|
|
||||||
# token = request.headers.get('Authorization', None)
|
|
||||||
# if token is None:
|
|
||||||
# return returns.NO_AUTHORIZATION
|
|
||||||
# if not token.startswith('Bearer '):
|
|
||||||
# return returns.INVALID_AUTHORIZATION
|
|
||||||
# token = token[7:]
|
|
||||||
# user_id = ram_db.get_user(token)
|
|
||||||
# if user_id is None:
|
|
||||||
# return returns.INVALID_AUTHORIZATION
|
|
||||||
#
|
|
||||||
# if pass_user_id and pass_token:
|
|
||||||
# return fn(user_id=user_id, token=token, *args, **kargs)
|
|
||||||
# elif pass_user_id:
|
|
||||||
# return fn(user_id=user_id, *args, **kargs)
|
|
||||||
# elif pass_token:
|
|
||||||
# return fn(token=token, *args, **kargs)
|
|
||||||
# else:
|
|
||||||
# return fn(*args, **kargs)
|
|
||||||
# return wrapper
|
|
||||||
# return decorator
|
|
||||||
|
|
||||||
sys.modules[__name__] = Module(__name__)
|
sys.modules[__name__] = Module(__name__)
|
||||||
|
|
|
@ -31,6 +31,13 @@ NOT_FOUND = _make_error(
|
||||||
'general/not_found',
|
'general/not_found',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def invalid_argument(argname: str) -> tuple[Any, int]:
|
||||||
|
return _make_error(
|
||||||
|
_HTTPStatus.UNPROCESSABLE_ENTITY,
|
||||||
|
'general/invalid_argument',
|
||||||
|
message=f'Invalid argument: {argname}',
|
||||||
|
)
|
||||||
|
|
||||||
# Login
|
# Login
|
||||||
|
|
||||||
INVALID_DETAILS = _make_error(
|
INVALID_DETAILS = _make_error(
|
||||||
|
@ -76,3 +83,12 @@ class ErrorSchema(Schema):
|
||||||
|
|
||||||
class SuccessSchema(Schema):
|
class SuccessSchema(Schema):
|
||||||
status = fields.Constant('success')
|
status = fields.Constant('success')
|
||||||
|
|
||||||
|
# smorest
|
||||||
|
|
||||||
|
def abort(result: tuple[Any, int]):
|
||||||
|
try:
|
||||||
|
from flask_smorest import abort as _abort
|
||||||
|
_abort(result[1], response=result)
|
||||||
|
except ImportError:
|
||||||
|
return result
|
||||||
|
|
|
@ -1,47 +0,0 @@
|
||||||
from functools import wraps
|
|
||||||
from flask import Blueprint, request
|
|
||||||
|
|
||||||
from pyotp import TOTP
|
|
||||||
|
|
||||||
import db_utils
|
|
||||||
from decorators import no_content, ensure_logged_in, user_id, token
|
|
||||||
import models
|
|
||||||
import ram_db
|
|
||||||
import returns
|
|
||||||
|
|
||||||
login = Blueprint('login', __name__)
|
|
||||||
|
|
||||||
@login.post('/')
|
|
||||||
def make_login():
|
|
||||||
try:
|
|
||||||
username = request.json['username']
|
|
||||||
code = request.json['code']
|
|
||||||
except (TypeError, KeyError):
|
|
||||||
return returns.INVALID_REQUEST
|
|
||||||
|
|
||||||
user: models.User | None = db_utils.get_user(username=username)
|
|
||||||
if user is None:
|
|
||||||
return returns.INVALID_DETAILS
|
|
||||||
|
|
||||||
otp = TOTP(user.otp)
|
|
||||||
if not otp.verify(code, valid_window=1):
|
|
||||||
return returns.INVALID_DETAILS
|
|
||||||
|
|
||||||
token = ram_db.login_user(user.id)
|
|
||||||
return returns.success(token=token)
|
|
||||||
|
|
||||||
@login.post('/logout')
|
|
||||||
@ensure_logged_in
|
|
||||||
@no_content
|
|
||||||
def logout():
|
|
||||||
ram_db.logout_user(token)
|
|
||||||
|
|
||||||
@login.get('/whoami')
|
|
||||||
@ensure_logged_in
|
|
||||||
def whoami():
|
|
||||||
user: models.User | None = db_utils.get_user(user_id=user_id)
|
|
||||||
if user is not None:
|
|
||||||
user = user.to_json()
|
|
||||||
|
|
||||||
return returns.successs(user=user)
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue