mirror of
				https://github.com/dancojocaru2000/foxbank.git
				synced 2025-11-04 03:46:31 +02:00 
			
		
		
		
	Fixed error handling
This commit is contained in:
		
							parent
							
								
									9ded9cc604
								
							
						
					
					
						commit
						31511f6004
					
				
					 6 changed files with 34 additions and 95 deletions
				
			
		| 
						 | 
					@ -4,8 +4,14 @@ from flask_smorest import Api
 | 
				
			||||||
from .accounts import bp as acc_bp
 | 
					from .accounts import bp as acc_bp
 | 
				
			||||||
from .login import bp as login_bp
 | 
					from .login import bp as login_bp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ApiWithErr(Api):
 | 
				
			||||||
 | 
					    def handle_http_exception(self, error):
 | 
				
			||||||
 | 
					        if error.data and error.data['response']:
 | 
				
			||||||
 | 
					            return error.data['response']
 | 
				
			||||||
 | 
					        return super().handle_http_exception(error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def init_apis(app: Flask):
 | 
					def init_apis(app: Flask):
 | 
				
			||||||
    api = Api(app, spec_kwargs={
 | 
					    api = ApiWithErr(app, spec_kwargs={
 | 
				
			||||||
        'title': 'FoxBank',
 | 
					        'title': 'FoxBank',
 | 
				
			||||||
        'version': '1',
 | 
					        'version': '1',
 | 
				
			||||||
        'openapi_version': '3.0.0',
 | 
					        'openapi_version': '3.0.0',
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -40,9 +40,9 @@ def get_valid_account_types():
 | 
				
			||||||
def get_account_id(account_id: int):
 | 
					def get_account_id(account_id: int):
 | 
				
			||||||
    account = db_utils.get_account(account_id=account_id)
 | 
					    account = db_utils.get_account(account_id=account_id)
 | 
				
			||||||
    if account is None:
 | 
					    if account is None:
 | 
				
			||||||
        return returns.NOT_FOUND
 | 
					        return returns.abort(returns.NOT_FOUND)
 | 
				
			||||||
    if decorators.user_id != db_utils.whose_account(account):
 | 
					    if decorators.user_id != db_utils.whose_account(account):
 | 
				
			||||||
        return returns.UNAUTHORIZED
 | 
					        return returns.abort(returns.UNAUTHORIZED)
 | 
				
			||||||
    account = account.to_json()
 | 
					    account = account.to_json()
 | 
				
			||||||
    return returns.success(account=account)
 | 
					    return returns.success(account=account)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -54,9 +54,9 @@ def get_account_id(account_id: int):
 | 
				
			||||||
def get_account_iban(iban: str):
 | 
					def get_account_iban(iban: str):
 | 
				
			||||||
    account = db_utils.get_account(iban=iban)
 | 
					    account = db_utils.get_account(iban=iban)
 | 
				
			||||||
    if account is None:
 | 
					    if account is None:
 | 
				
			||||||
        return returns.NOT_FOUND
 | 
					        return returns.abort(returns.NOT_FOUND)
 | 
				
			||||||
    if decorators.user_id != db_utils.whose_account(account):
 | 
					    if decorators.user_id != db_utils.whose_account(account):
 | 
				
			||||||
        return returns.UNAUTHORIZED
 | 
					        return returns.abort(returns.UNAUTHORIZED)
 | 
				
			||||||
    account = account.to_json()
 | 
					    account = account.to_json()
 | 
				
			||||||
    return returns.success(account=account)
 | 
					    return returns.success(account=account)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -80,9 +80,9 @@ class AccountsList(MethodView):
 | 
				
			||||||
    def post(self, currency: str, account_type: str, custom_name: str):
 | 
					    def post(self, currency: str, account_type: str, custom_name: str):
 | 
				
			||||||
        """Create account"""
 | 
					        """Create account"""
 | 
				
			||||||
        if currency not in VALID_CURRENCIES:
 | 
					        if currency not in VALID_CURRENCIES:
 | 
				
			||||||
            abort(HTTPStatus.UNPROCESSABLE_ENTITY)
 | 
					            return returns.abort(returns.invalid_argument('currency'))
 | 
				
			||||||
        if account_type not in ACCOUNT_TYPES:
 | 
					        if account_type not in ACCOUNT_TYPES:
 | 
				
			||||||
            abort(HTTPStatus.UNPROCESSABLE_ENTITY)
 | 
					            return returns.abort(returns.invalid_argument('account_type'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        account = Account(-1, '', currency, account_type, custom_name or '')
 | 
					        account = Account(-1, '', currency, account_type, custom_name or '')
 | 
				
			||||||
        db_utils.insert_account(decorators.user_id, account)
 | 
					        db_utils.insert_account(decorators.user_id, account)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -29,11 +29,11 @@ class Login(MethodView):
 | 
				
			||||||
        """Login via username and TOTP code"""
 | 
					        """Login via username and TOTP code"""
 | 
				
			||||||
        user: User | None = get_user(username=username)
 | 
					        user: User | None = get_user(username=username)
 | 
				
			||||||
        if user is None:
 | 
					        if user is None:
 | 
				
			||||||
            return returns.INVALID_DETAILS
 | 
					            return returns.abort(returns.INVALID_DETAILS)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        otp = TOTP(user.otp)
 | 
					        otp = TOTP(user.otp)
 | 
				
			||||||
        if not otp.verify(code, valid_window=1):
 | 
					        if not otp.verify(code, valid_window=1):
 | 
				
			||||||
            return returns.INVALID_DETAILS
 | 
					            return returns.abort(returns.INVALID_DETAILS)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        token = ram_db.login_user(user.id)
 | 
					        token = ram_db.login_user(user.id)
 | 
				
			||||||
        return returns.success(token=token)
 | 
					        return returns.success(token=token)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -49,13 +49,13 @@ class Module(ModuleType):
 | 
				
			||||||
        def wrapper(*args, **kargs):
 | 
					        def wrapper(*args, **kargs):
 | 
				
			||||||
            token = request.headers.get('Authorization', None)
 | 
					            token = request.headers.get('Authorization', None)
 | 
				
			||||||
            if token is None:
 | 
					            if token is None:
 | 
				
			||||||
                return returns.NO_AUTHORIZATION
 | 
					                return returns.abort(returns.NO_AUTHORIZATION)
 | 
				
			||||||
            if not token.startswith('Bearer '):
 | 
					            if not token.startswith('Bearer '):
 | 
				
			||||||
                return returns.INVALID_AUTHORIZATION
 | 
					                return returns.abort(returns.INVALID_AUTHORIZATION)
 | 
				
			||||||
            token = token[7:]
 | 
					            token = token[7:]
 | 
				
			||||||
            user_id = ram_db.get_user(token)
 | 
					            user_id = ram_db.get_user(token)
 | 
				
			||||||
            if user_id is None:
 | 
					            if user_id is None:
 | 
				
			||||||
                return returns.INVALID_AUTHORIZATION
 | 
					                return returns.abort(returns.INVALID_AUTHORIZATION)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            global _token
 | 
					            global _token
 | 
				
			||||||
            _token = token
 | 
					            _token = token
 | 
				
			||||||
| 
						 | 
					@ -71,40 +71,4 @@ class Module(ModuleType):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return wrapper
 | 
					        return wrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # def ensure_logged_in(token=False, user_id=False):
 | 
					 | 
				
			||||||
    #     """
 | 
					 | 
				
			||||||
    #     Ensure the user is logged in by providing an Authorization: Bearer token
 | 
					 | 
				
			||||||
    #     header.
 | 
					 | 
				
			||||||
    #
 | 
					 | 
				
			||||||
    #     @param token whether the token should be supplied after validation
 | 
					 | 
				
			||||||
    #     @param user_id whether the user_id should be supplied after validation
 | 
					 | 
				
			||||||
    #     @return decorator which supplies the requested parameters
 | 
					 | 
				
			||||||
    #     """
 | 
					 | 
				
			||||||
    #     def decorator(fn):
 | 
					 | 
				
			||||||
    #         pass_token = token
 | 
					 | 
				
			||||||
    #         pass_user_id = user_id
 | 
					 | 
				
			||||||
    #
 | 
					 | 
				
			||||||
    #         @wraps(fn)
 | 
					 | 
				
			||||||
    #         def wrapper(*args, **kargs):
 | 
					 | 
				
			||||||
    #             token = request.headers.get('Authorization', None)
 | 
					 | 
				
			||||||
    #             if token is None:
 | 
					 | 
				
			||||||
    #                 return returns.NO_AUTHORIZATION
 | 
					 | 
				
			||||||
    #             if not token.startswith('Bearer '):
 | 
					 | 
				
			||||||
    #                 return returns.INVALID_AUTHORIZATION
 | 
					 | 
				
			||||||
    #             token = token[7:]
 | 
					 | 
				
			||||||
    #             user_id = ram_db.get_user(token)
 | 
					 | 
				
			||||||
    #             if user_id is None:
 | 
					 | 
				
			||||||
    #                 return returns.INVALID_AUTHORIZATION
 | 
					 | 
				
			||||||
    #
 | 
					 | 
				
			||||||
    #             if pass_user_id and pass_token:
 | 
					 | 
				
			||||||
    #                 return fn(user_id=user_id, token=token, *args, **kargs)
 | 
					 | 
				
			||||||
    #             elif pass_user_id:
 | 
					 | 
				
			||||||
    #                 return fn(user_id=user_id, *args, **kargs)
 | 
					 | 
				
			||||||
    #             elif pass_token:
 | 
					 | 
				
			||||||
    #                 return fn(token=token, *args, **kargs)
 | 
					 | 
				
			||||||
    #             else:
 | 
					 | 
				
			||||||
    #                 return fn(*args, **kargs)
 | 
					 | 
				
			||||||
    #         return wrapper
 | 
					 | 
				
			||||||
    #     return decorator
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
sys.modules[__name__] = Module(__name__)
 | 
					sys.modules[__name__] = Module(__name__)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -31,6 +31,13 @@ NOT_FOUND = _make_error(
 | 
				
			||||||
    'general/not_found',
 | 
					    'general/not_found',
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def invalid_argument(argname: str) -> tuple[Any, int]:
 | 
				
			||||||
 | 
					    return _make_error(
 | 
				
			||||||
 | 
					        _HTTPStatus.UNPROCESSABLE_ENTITY,
 | 
				
			||||||
 | 
					        'general/invalid_argument',
 | 
				
			||||||
 | 
					        message=f'Invalid argument: {argname}',
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Login
 | 
					# Login
 | 
				
			||||||
 | 
					
 | 
				
			||||||
INVALID_DETAILS = _make_error(
 | 
					INVALID_DETAILS = _make_error(
 | 
				
			||||||
| 
						 | 
					@ -76,3 +83,12 @@ class ErrorSchema(Schema):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SuccessSchema(Schema):
 | 
					class SuccessSchema(Schema):
 | 
				
			||||||
    status = fields.Constant('success')
 | 
					    status = fields.Constant('success')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# smorest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def abort(result: tuple[Any, int]):
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        from flask_smorest import abort as _abort
 | 
				
			||||||
 | 
					        _abort(result[1], response=result)
 | 
				
			||||||
 | 
					    except ImportError:
 | 
				
			||||||
 | 
					        return result
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,47 +0,0 @@
 | 
				
			||||||
from functools import wraps
 | 
					 | 
				
			||||||
from flask import Blueprint, request
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from pyotp import TOTP
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import db_utils
 | 
					 | 
				
			||||||
from decorators import no_content, ensure_logged_in, user_id, token
 | 
					 | 
				
			||||||
import models
 | 
					 | 
				
			||||||
import ram_db
 | 
					 | 
				
			||||||
import returns
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
login = Blueprint('login', __name__)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@login.post('/')
 | 
					 | 
				
			||||||
def make_login():
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        username = request.json['username']
 | 
					 | 
				
			||||||
        code = request.json['code']
 | 
					 | 
				
			||||||
    except (TypeError, KeyError):
 | 
					 | 
				
			||||||
        return returns.INVALID_REQUEST
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    user: models.User | None = db_utils.get_user(username=username)
 | 
					 | 
				
			||||||
    if user is None:
 | 
					 | 
				
			||||||
        return returns.INVALID_DETAILS
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    otp = TOTP(user.otp)
 | 
					 | 
				
			||||||
    if not otp.verify(code, valid_window=1):
 | 
					 | 
				
			||||||
        return returns.INVALID_DETAILS
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    token = ram_db.login_user(user.id)
 | 
					 | 
				
			||||||
    return returns.success(token=token)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@login.post('/logout')
 | 
					 | 
				
			||||||
@ensure_logged_in
 | 
					 | 
				
			||||||
@no_content
 | 
					 | 
				
			||||||
def logout():
 | 
					 | 
				
			||||||
    ram_db.logout_user(token)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@login.get('/whoami')
 | 
					 | 
				
			||||||
@ensure_logged_in
 | 
					 | 
				
			||||||
def whoami():
 | 
					 | 
				
			||||||
    user: models.User | None = db_utils.get_user(user_id=user_id)
 | 
					 | 
				
			||||||
    if user is not None:
 | 
					 | 
				
			||||||
        user = user.to_json()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return returns.successs(user=user)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		Loading…
	
	Add table
		
		Reference in a new issue