From 3448e3004894cef8d46fff3e2ebac22ec8e3552f Mon Sep 17 00:00:00 2001 From: Stefan Nilsson Date: Wed, 12 Feb 2025 16:49:46 +0100 Subject: [PATCH] Improved error handling --- main.py | 147 +++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 98 insertions(+), 49 deletions(-) diff --git a/main.py b/main.py index 5a30d8d..4b0b183 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,5 @@ -from fastapi import FastAPI, HTTPException -from fastapi.responses import HTMLResponse +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import HTMLResponse, JSONResponse import requests import os import subprocess @@ -27,62 +27,95 @@ def init_db(): init_db() +def execute_query(query, params=()): + try: + conn = sqlite3.connect("servers.db") + cursor = conn.cursor() + cursor.execute(query, params) + conn.commit() + conn.close() + except sqlite3.Error as e: + raise DatabaseError(str(e)) + +def fetch_query(query, params=()): + try: + conn = sqlite3.connect("servers.db") + cursor = conn.cursor() + cursor.execute(query, params) + result = cursor.fetchall() + conn.close() + return result + except sqlite3.Error as e: + raise DatabaseError(str(e)) + def get_server_info(server_name): - conn = sqlite3.connect("servers.db") - cursor = conn.cursor() - cursor.execute("SELECT servers.ip, servers.mac, clusters.api_token FROM servers JOIN clusters ON servers.cluster_name = clusters.name WHERE servers.name = ?", (server_name,)) - server = cursor.fetchone() - conn.close() - if server: - return {"ip": server[0], "mac": server[1], "api_token": server[2]} + result = fetch_query( + "SELECT servers.ip, servers.mac, clusters.api_token FROM servers JOIN clusters ON servers.cluster_name = clusters.name WHERE servers.name = ?", + (server_name,) + ) + if result: + return {"ip": result[0][0], "mac": result[0][1], "api_token": result[0][2]} return None def list_servers(): - conn = sqlite3.connect("servers.db") - cursor = conn.cursor() - cursor.execute("SELECT servers.name, servers.ip, servers.mac, servers.cluster_name FROM servers") - servers = cursor.fetchall() - conn.close() - return [{"name": s[0], "ip": s[1], "mac": s[2], "cluster_name": s[3]} for s in servers] + servers = fetch_query( + "SELECT servers.name, servers.ip, servers.mac, servers.cluster_name FROM servers" + ) + if servers: + return [{"name": s[0], "ip": s[1], "mac": s[2], "cluster_name": s[3]} for s in servers] + return None def list_clusters(): - conn = sqlite3.connect("servers.db") - cursor = conn.cursor() - cursor.execute("SELECT name FROM clusters") - clusters = cursor.fetchall() - conn.close() - return [c[0] for c in clusters] + clusters = fetch_query("SELECT name FROM clusters") + if clusters: + return [c[0] for c in clusters] + return None def add_or_update_cluster(name: str, api_token: str): - conn = sqlite3.connect("servers.db") - cursor = conn.cursor() - cursor.execute("REPLACE INTO clusters (name, api_token) VALUES (?, ?)", (name, api_token)) - conn.commit() - conn.close() + execute_query("REPLACE INTO clusters (name, api_token) VALUES (?, ?)", (name, api_token)) def delete_cluster(name: str): - conn = sqlite3.connect("servers.db") - cursor = conn.cursor() - cursor.execute("DELETE FROM clusters WHERE name = ?", (name,)) - conn.commit() - conn.close() + execute_query("DELETE FROM clusters WHERE name = ?", (name,)) def add_or_update_server(name: str, ip: str, mac: str, cluster_name: str): - conn = sqlite3.connect("servers.db") - cursor = conn.cursor() - cursor.execute("REPLACE INTO servers (name, ip, mac, cluster_name) VALUES (?, ?, ?, ?)", (name, ip, mac, cluster_name)) - conn.commit() - conn.close() + execute_query("REPLACE INTO servers (name, ip, mac, cluster_name) VALUES (?, ?, ?, ?)", (name, ip, mac, cluster_name)) def delete_server(name: str): - conn = sqlite3.connect("servers.db") - cursor = conn.cursor() - cursor.execute("DELETE FROM servers WHERE name = ?", (name,)) - conn.commit() - conn.close() + execute_query("DELETE FROM servers WHERE name = ?", (name,)) app = FastAPI() +class DatabaseError(Exception): + """Custom exception for database errors.""" + def __init__(self, message: str): + self.message = message + +class ProxmoxAPIError(Exception): + """Custom exception for Proxmox API errors.""" + def __init__(self, message: str): + self.message = message + +@app.exception_handler(DatabaseError) +def database_error_handler(request: Request, exc: DatabaseError): + return JSONResponse( + status_code=500, + content={"error": "Database Error", "message": exc.message}, + ) + +@app.exception_handler(ProxmoxAPIError) +def proxmox_api_error_handler(request: Request, exc: ProxmoxAPIError): + return JSONResponse( + status_code=500, + content={"error": "Proxmox API Error", "message": exc.message}, + ) + +@app.exception_handler(Exception) +def generic_exception_handler(request: Request, exc: Exception): + return JSONResponse( + status_code=500, + content={"error": "Internal Server Error", "message": str(exc)}, + ) + class ServerModel(BaseModel): name: str ip: str @@ -101,8 +134,12 @@ def get_proxmox_status(server_ip, api_token): response.raise_for_status() data = response.json() return {"status": data["data"][0]["status"]} # Extract node status - except requests.RequestException as e: - return {"status": "unknown", "error": str(e)} + except requests.exceptions.Timeout: + raise ProxmoxAPIError("Connection to Proxmox server timed out") + except requests.exceptions.ConnectionError: + raise ProxmoxAPIError("Failed to connect to Proxmox server") + except requests.exceptions.RequestException as e: + raise ProxmoxAPIError(str(e)) @app.get("/openapi.json") def get_openapi_spec(): @@ -139,8 +176,12 @@ def list_all_statuses(): servers = list_servers() statuses = {} for server in servers: - cluster_token = get_server_info(server["name"])["api_token"] - statuses[server["name"]] = get_proxmox_status(server["ip"], cluster_token) + try: + cluster_token = get_server_info(server["name"])["api_token"] + statuses[server["name"]] = get_proxmox_status(server["ip"], cluster_token) + except ProxmoxAPIError as e: + statuses[server["name"]] = {"status": "unknown", "error": str(e)} + return statuses @app.get("/clusters") @@ -177,7 +218,15 @@ def get_power_status(server_name: str): def list_all_states(): """Returns the power states of all servers.""" servers = list_servers() - return {server["name"]: check_power_state(server["ip"]) for server in servers} + power_states = {} + + for server in servers: + try: + power_states[server["name"]] = check_power_state(server["ip"]) + except Exception as e: + power_states[server["name"]] = {"power": "unknown", "error": str(e)} + + return power_states @app.put("/states/{server_name}") @app.patch("/states/{server_name}") @@ -189,15 +238,15 @@ def control_power(server_name: str, state: str): if state == "on": if check_power_state(server["ip"]) == "off": - send_magic_packet(server["mac"]) + send_magic_packet(server["mac"], ip_address="192.168.1.255") return {"message": "Wake-on-LAN signal sent"} - return {"message": "Server is already on"} + raise HTTPException(status_code=400, detail="Server is already on") elif state == "off": try: requests.post(f"https://{server["ip"]}:8006/api2/json/nodes/shutdown", verify=False) return {"message": "Shutdown command sent"} except requests.RequestException as e: - raise HTTPException(status_code=500, detail=str(e)) + raise ProxmoxAPIError(f"Failed to send shutdown command: {str(e)}") else: raise HTTPException(status_code=400, detail="Invalid state. Use 'on' or 'off'")