Initial commit — Stupid Simple Network Inventory
Application web d'inventaire réseau manuel avec FastAPI, Vue 3 et Docker. Inclut l'authentification JWT, la découverte ICMP, et la topologie en cards CSS. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,218 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import time
|
||||
import threading
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_SECRET_KEY_FILE = "data/secret_key.txt"
|
||||
_audit = logging.getLogger("audit")
|
||||
|
||||
|
||||
def _log_audit(event: str, **kw) -> None:
|
||||
_audit.info(json.dumps({"event": event, "ts": datetime.now(timezone.utc).isoformat(), **kw}))
|
||||
|
||||
|
||||
def _load_secret_key() -> str:
|
||||
env = os.environ.get("SECRET_KEY")
|
||||
if env:
|
||||
return env
|
||||
if os.path.exists(_SECRET_KEY_FILE):
|
||||
return open(_SECRET_KEY_FILE).read().strip()
|
||||
key = secrets.token_hex(32)
|
||||
os.makedirs(os.path.dirname(_SECRET_KEY_FILE), exist_ok=True)
|
||||
# Create with owner-only permissions (0600) to prevent other users from reading the key
|
||||
fd = os.open(_SECRET_KEY_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||
try:
|
||||
with os.fdopen(fd, "w") as f:
|
||||
f.write(key)
|
||||
except Exception:
|
||||
os.close(fd)
|
||||
raise
|
||||
return key
|
||||
|
||||
|
||||
SECRET_KEY = _load_secret_key()
|
||||
ALGORITHM = "HS256"
|
||||
TOKEN_EXPIRE_HOURS = 24
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
|
||||
|
||||
# --- Rate limiting ---
|
||||
_login_attempts: dict[str, list[float]] = {} # username → timestamps
|
||||
_ip_attempts: dict[str, list[float]] = {} # ip → timestamps
|
||||
_rate_lock = threading.Lock()
|
||||
_USERNAME_WINDOW = 900 # 15 min
|
||||
_USERNAME_MAX = 10
|
||||
_IP_WINDOW = 60 # 1 min
|
||||
_IP_MAX = 20
|
||||
|
||||
|
||||
def _check_username_rate_limit(username: str) -> None:
|
||||
now = time.time()
|
||||
with _rate_lock:
|
||||
attempts = [t for t in _login_attempts.get(username, []) if now - t < _USERNAME_WINDOW]
|
||||
if len(attempts) >= _USERNAME_MAX:
|
||||
raise HTTPException(status_code=429, detail="Too many attempts, try again later")
|
||||
attempts.append(now)
|
||||
_login_attempts[username] = attempts
|
||||
|
||||
|
||||
def _check_ip_rate_limit(ip: str) -> None:
|
||||
now = time.time()
|
||||
with _rate_lock:
|
||||
attempts = [t for t in _ip_attempts.get(ip, []) if now - t < _IP_WINDOW]
|
||||
if len(attempts) >= _IP_MAX:
|
||||
raise HTTPException(status_code=429, detail="Too many attempts, try again later")
|
||||
attempts.append(now)
|
||||
_ip_attempts[ip] = attempts
|
||||
|
||||
|
||||
def _clear_login_attempts(username: str) -> None:
|
||||
with _rate_lock:
|
||||
_login_attempts.pop(username, None)
|
||||
|
||||
|
||||
def create_token(username: str, version: int) -> str:
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=TOKEN_EXPIRE_HOURS)
|
||||
return jwt.encode(
|
||||
{"sub": username, "ver": version, "exp": expire},
|
||||
SECRET_KEY,
|
||||
algorithm=ALGORITHM,
|
||||
)
|
||||
|
||||
|
||||
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> User:
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
token_ver: int = payload.get("ver", 1)
|
||||
if not username:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=401, detail="Invalid token", headers={"WWW-Authenticate": "Bearer"})
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
_log_audit("auth.token_rejected", username=username, reason="user_not_found")
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
if (user.token_version or 1) != token_ver:
|
||||
_log_audit("auth.token_rejected", username=username, reason="version_mismatch")
|
||||
raise HTTPException(status_code=401, detail="Session expired, please log in again")
|
||||
return user
|
||||
|
||||
|
||||
def require_password_changed(current_user: User = Depends(get_current_user)) -> User:
|
||||
if current_user.must_change_password:
|
||||
raise HTTPException(status_code=403, detail="Password change required")
|
||||
return current_user
|
||||
|
||||
|
||||
# --- Validation helpers ---
|
||||
_USERNAME_RE = re.compile(r"^[a-zA-Z0-9._-]{1,64}$")
|
||||
|
||||
|
||||
def _validate_new_password(password: str) -> None:
|
||||
if len(password) < 8:
|
||||
raise HTTPException(status_code=400, detail="password_too_short")
|
||||
if not re.search(r"[a-zA-Z]", password) or not re.search(r"[0-9]", password):
|
||||
raise HTTPException(status_code=400, detail="password_too_weak")
|
||||
|
||||
|
||||
def _validate_new_username(username: str) -> None:
|
||||
if not _USERNAME_RE.match(username):
|
||||
raise HTTPException(status_code=400, detail="username_invalid")
|
||||
|
||||
|
||||
class TokenOut(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
username: str
|
||||
must_change_password: bool = False
|
||||
|
||||
|
||||
class AccountUpdate(BaseModel):
|
||||
current_password: str
|
||||
new_username: str | None = None
|
||||
new_password: str | None = None
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenOut)
|
||||
def login(request: Request, form: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
try:
|
||||
_check_ip_rate_limit(client_ip)
|
||||
except HTTPException:
|
||||
_log_audit("auth.login.rate_limited", ip=client_ip, reason="ip")
|
||||
raise
|
||||
try:
|
||||
_check_username_rate_limit(form.username)
|
||||
except HTTPException:
|
||||
_log_audit("auth.login.rate_limited", ip=client_ip, username=form.username, reason="username")
|
||||
raise
|
||||
user = db.query(User).filter(User.username == form.username).first()
|
||||
if not user or not pwd_context.verify(form.password, user.hashed_password):
|
||||
_log_audit("auth.login.failure", username=form.username, ip=client_ip)
|
||||
raise HTTPException(status_code=401, detail="Incorrect username or password")
|
||||
_clear_login_attempts(form.username)
|
||||
_log_audit("auth.login.success", username=user.username, ip=client_ip)
|
||||
return {
|
||||
"access_token": create_token(user.username, user.token_version or 1),
|
||||
"token_type": "bearer",
|
||||
"username": user.username,
|
||||
"must_change_password": bool(user.must_change_password),
|
||||
}
|
||||
|
||||
|
||||
@router.put("/account", response_model=TokenOut)
|
||||
def update_account(
|
||||
data: AccountUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
if not pwd_context.verify(data.current_password, current_user.hashed_password):
|
||||
_log_audit("auth.account.bad_password", username=current_user.username)
|
||||
raise HTTPException(status_code=400, detail="Current password is incorrect")
|
||||
if data.new_username and data.new_username != current_user.username:
|
||||
_validate_new_username(data.new_username)
|
||||
if db.query(User).filter(User.username == data.new_username).first():
|
||||
raise HTTPException(status_code=400, detail="Username already taken")
|
||||
old_username = current_user.username
|
||||
current_user.username = data.new_username
|
||||
_log_audit("auth.account.username_changed", old_username=old_username, new_username=data.new_username)
|
||||
if data.new_password:
|
||||
_validate_new_password(data.new_password)
|
||||
current_user.hashed_password = pwd_context.hash(data.new_password)
|
||||
current_user.must_change_password = False
|
||||
# Invalidate all previously issued tokens by bumping the version
|
||||
current_user.token_version = (current_user.token_version or 1) + 1
|
||||
_log_audit("auth.account.password_changed", username=current_user.username)
|
||||
db.commit()
|
||||
return {
|
||||
"access_token": create_token(current_user.username, current_user.token_version or 1),
|
||||
"token_type": "bearer",
|
||||
"username": current_user.username,
|
||||
"must_change_password": bool(current_user.must_change_password),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
def get_me(current_user: User = Depends(get_current_user)):
|
||||
return {
|
||||
"username": current_user.username,
|
||||
"must_change_password": bool(current_user.must_change_password),
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
import ipaddress
|
||||
import re
|
||||
from typing import Optional, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
import models
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_VALID_TYPES = {
|
||||
"server", "switch", "router", "nas", "gateway", "livebox", "access_point",
|
||||
"camera", "temperature", "sensor", "hub", "smart_plug", "alarm", "light",
|
||||
"doorbell", "desktop", "laptop", "other",
|
||||
}
|
||||
_VALID_VIRT_TYPES = {None, "baremetal", "lxc", "qemu"}
|
||||
|
||||
|
||||
class InterfaceCreate(BaseModel):
|
||||
name: str = "eth0"
|
||||
ip_address: Optional[str] = ""
|
||||
vlan_id: Optional[int] = None
|
||||
is_upstream: bool = False
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def _name(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("Interface name cannot be empty")
|
||||
if len(v) > 50:
|
||||
raise ValueError("Interface name too long (max 50 characters)")
|
||||
return v
|
||||
|
||||
@field_validator("ip_address")
|
||||
@classmethod
|
||||
def _ip(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v:
|
||||
try:
|
||||
ipaddress.ip_address(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid IP address: {v!r}")
|
||||
return v
|
||||
|
||||
|
||||
class InterfaceOut(InterfaceCreate):
|
||||
id: int
|
||||
device_id: int
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DeviceCreate(BaseModel):
|
||||
name: str
|
||||
type: str = "other"
|
||||
description: str = ""
|
||||
is_gateway: bool = False
|
||||
is_livebox: bool = False
|
||||
virt_type: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
interfaces: List[InterfaceCreate] = []
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def _name(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("name cannot be empty")
|
||||
if len(v) > 100:
|
||||
raise ValueError("name too long (max 100 characters)")
|
||||
return v
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def _description(cls, v: str) -> str:
|
||||
if len(v) > 500:
|
||||
raise ValueError("description too long (max 500 characters)")
|
||||
return v
|
||||
|
||||
@field_validator("type")
|
||||
@classmethod
|
||||
def _type(cls, v: str) -> str:
|
||||
if v not in _VALID_TYPES:
|
||||
raise ValueError(f"Invalid type: {v!r}. Must be one of: {sorted(_VALID_TYPES)}")
|
||||
return v
|
||||
|
||||
@field_validator("virt_type")
|
||||
@classmethod
|
||||
def _virt_type(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v not in _VALID_VIRT_TYPES:
|
||||
raise ValueError(f"Invalid virt_type: {v!r}. Must be one of: baremetal, lxc, qemu")
|
||||
return v
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def _url(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v:
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https") or not parsed.netloc:
|
||||
raise ValueError("url must be a valid http or https URL")
|
||||
return v
|
||||
|
||||
|
||||
class DeviceOut(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
is_gateway: bool
|
||||
is_livebox: bool
|
||||
virt_type: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
interfaces: List[InterfaceOut] = []
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
@router.get("/", response_model=List[DeviceOut])
|
||||
def list_devices(db: Session = Depends(get_db)):
|
||||
return db.query(models.Device).order_by(models.Device.name).all()
|
||||
|
||||
|
||||
@router.post("/", response_model=DeviceOut)
|
||||
def create_device(device: DeviceCreate, db: Session = Depends(get_db)):
|
||||
db_device = models.Device(
|
||||
name=device.name,
|
||||
type=device.type,
|
||||
description=device.description,
|
||||
is_gateway=device.is_gateway,
|
||||
is_livebox=device.is_livebox,
|
||||
virt_type=device.virt_type,
|
||||
url=device.url,
|
||||
)
|
||||
db.add(db_device)
|
||||
db.flush()
|
||||
for iface in device.interfaces:
|
||||
db.add(models.DeviceInterface(device_id=db_device.id, **iface.model_dump()))
|
||||
db.commit()
|
||||
db.refresh(db_device)
|
||||
return db_device
|
||||
|
||||
|
||||
@router.put("/{device_id}", response_model=DeviceOut)
|
||||
def update_device(device_id: int, device: DeviceCreate, db: Session = Depends(get_db)):
|
||||
db_device = db.query(models.Device).filter(models.Device.id == device_id).first()
|
||||
if not db_device:
|
||||
raise HTTPException(status_code=404, detail="Équipement introuvable")
|
||||
db_device.name = device.name
|
||||
db_device.type = device.type
|
||||
db_device.description = device.description
|
||||
db_device.is_gateway = device.is_gateway
|
||||
db_device.is_livebox = device.is_livebox
|
||||
db_device.virt_type = device.virt_type
|
||||
db_device.url = device.url
|
||||
db.query(models.DeviceInterface).filter(
|
||||
models.DeviceInterface.device_id == device_id
|
||||
).delete()
|
||||
for iface in device.interfaces:
|
||||
db.add(models.DeviceInterface(device_id=device_id, **iface.model_dump()))
|
||||
db.commit()
|
||||
db.refresh(db_device)
|
||||
return db_device
|
||||
|
||||
|
||||
@router.delete("/{device_id}")
|
||||
def delete_device(device_id: int, db: Session = Depends(get_db)):
|
||||
db_device = db.query(models.Device).filter(models.Device.id == device_id).first()
|
||||
if not db_device:
|
||||
raise HTTPException(status_code=404, detail="Équipement introuvable")
|
||||
db.delete(db_device)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
@@ -0,0 +1,156 @@
|
||||
import ipaddress
|
||||
import subprocess
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Optional
|
||||
|
||||
import dns.resolver
|
||||
import dns.reversename
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
MAX_HOSTS_PER_TARGET = 1024 # refuse les /21 et plus larges
|
||||
MAX_HOSTS_TOTAL = 4096 # cap global sur l'ensemble des targets
|
||||
|
||||
|
||||
class ScanTarget(BaseModel):
|
||||
vlan_id: int
|
||||
cidr: str
|
||||
|
||||
|
||||
class ScanRequest(BaseModel):
|
||||
dns_server: str = "8.8.8.8"
|
||||
targets: list[ScanTarget]
|
||||
|
||||
@field_validator("dns_server")
|
||||
@classmethod
|
||||
def _dns_server(cls, v: str) -> str:
|
||||
try:
|
||||
ipaddress.ip_address(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"dns_server must be a valid IP address, got: {v!r}")
|
||||
return v
|
||||
|
||||
|
||||
class DiscoveredHost(BaseModel):
|
||||
ip: str
|
||||
hostname: Optional[str] = None
|
||||
vlan_id: int
|
||||
cidr: str
|
||||
|
||||
|
||||
class ScanResponse(BaseModel):
|
||||
hosts: list[DiscoveredHost]
|
||||
total_scanned: int
|
||||
duration_s: float
|
||||
|
||||
|
||||
def _ping(ip: str) -> bool:
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["ping", "-c", "1", "-W", "1", ip],
|
||||
capture_output=True,
|
||||
timeout=3,
|
||||
)
|
||||
return r.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _ptr_lookup(ip: str, nameserver: str) -> Optional[str]:
|
||||
try:
|
||||
resolver = dns.resolver.Resolver(configure=False)
|
||||
resolver.nameservers = [nameserver]
|
||||
resolver.timeout = 1
|
||||
resolver.lifetime = 2
|
||||
rev = dns.reversename.from_address(ip)
|
||||
ans = resolver.resolve(rev, "PTR")
|
||||
return str(ans[0]).rstrip(".")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _scan_one(ip: str, dns_server: str, vlan_id: int, cidr: str) -> Optional[DiscoveredHost]:
|
||||
if not _ping(ip):
|
||||
return None
|
||||
hostname = _ptr_lookup(ip, dns_server)
|
||||
return DiscoveredHost(ip=ip, hostname=hostname, vlan_id=vlan_id, cidr=cidr)
|
||||
|
||||
|
||||
class PingRequest(BaseModel):
|
||||
ips: list[str]
|
||||
|
||||
@field_validator("ips")
|
||||
@classmethod
|
||||
def _ips(cls, v: list[str]) -> list[str]:
|
||||
for ip in v:
|
||||
try:
|
||||
ipaddress.ip_address(ip)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid IP address: {ip!r}")
|
||||
return v
|
||||
|
||||
|
||||
class PingResult(BaseModel):
|
||||
ip: str
|
||||
alive: bool
|
||||
|
||||
|
||||
@router.post("/ping", response_model=list[PingResult])
|
||||
def ping_many(req: PingRequest):
|
||||
if not req.ips:
|
||||
return []
|
||||
with ThreadPoolExecutor(max_workers=50) as pool:
|
||||
futures = {pool.submit(_ping, ip): ip for ip in req.ips}
|
||||
results = [PingResult(ip=futures[f], alive=f.result()) for f in as_completed(futures)]
|
||||
return results
|
||||
|
||||
|
||||
@router.post("/scan", response_model=ScanResponse)
|
||||
def scan(req: ScanRequest):
|
||||
tasks: list[tuple[str, str, int, str]] = []
|
||||
|
||||
for t in req.targets:
|
||||
try:
|
||||
net = ipaddress.ip_network(t.cidr, strict=False)
|
||||
except ValueError:
|
||||
raise HTTPException(400, f"CIDR invalide : {t.cidr}")
|
||||
|
||||
hosts = list(net.hosts())
|
||||
if len(hosts) > MAX_HOSTS_PER_TARGET:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"Réseau {t.cidr} trop large ({len(hosts)} hôtes). "
|
||||
f"Maximum par target : {MAX_HOSTS_PER_TARGET} hôtes (/22 ou plus petit).",
|
||||
)
|
||||
for ip in hosts:
|
||||
tasks.append((str(ip), req.dns_server, t.vlan_id, t.cidr))
|
||||
|
||||
if not tasks:
|
||||
raise HTTPException(400, "Aucune cible à scanner.")
|
||||
|
||||
if len(tasks) > MAX_HOSTS_TOTAL:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"Trop d'hôtes au total ({len(tasks)}). Maximum global : {MAX_HOSTS_TOTAL}.",
|
||||
)
|
||||
|
||||
t0 = time.time()
|
||||
results: list[DiscoveredHost] = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=100) as pool:
|
||||
futures = [pool.submit(_scan_one, *args) for args in tasks]
|
||||
for f in as_completed(futures):
|
||||
host = f.result()
|
||||
if host:
|
||||
results.append(host)
|
||||
|
||||
results.sort(key=lambda h: ipaddress.ip_address(h.ip))
|
||||
|
||||
return ScanResponse(
|
||||
hosts=results,
|
||||
total_scanned=len(tasks),
|
||||
duration_s=round(time.time() - t0, 1),
|
||||
)
|
||||
@@ -0,0 +1,102 @@
|
||||
import ipaddress
|
||||
import re
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, field_validator
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
import models
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_COLOR_RE = re.compile(r"^#[0-9a-fA-F]{6}$")
|
||||
|
||||
|
||||
class VlanCreate(BaseModel):
|
||||
vlan_id: Optional[int] = None
|
||||
name: str
|
||||
cidr: Optional[str] = ""
|
||||
color: str = "#4A90D9"
|
||||
|
||||
@field_validator("vlan_id")
|
||||
@classmethod
|
||||
def _vlan_id(cls, v: Optional[int]) -> Optional[int]:
|
||||
if v is not None and not (1 <= v <= 4094):
|
||||
raise ValueError("vlan_id must be between 1 and 4094")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def _name(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("name cannot be empty")
|
||||
if len(v) > 100:
|
||||
raise ValueError("name too long (max 100 characters)")
|
||||
return v
|
||||
|
||||
@field_validator("cidr")
|
||||
@classmethod
|
||||
def _cidr(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v:
|
||||
try:
|
||||
ipaddress.ip_network(v, strict=False)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid CIDR notation: {v!r}")
|
||||
return v
|
||||
|
||||
@field_validator("color")
|
||||
@classmethod
|
||||
def _color(cls, v: str) -> str:
|
||||
if not _COLOR_RE.match(v):
|
||||
raise ValueError("color must be a 6-digit hex color (e.g. #4A90D9)")
|
||||
return v
|
||||
|
||||
|
||||
class VlanOut(VlanCreate):
|
||||
id: int
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
@router.get("/", response_model=List[VlanOut])
|
||||
def list_vlans(db: Session = Depends(get_db)):
|
||||
return db.query(models.Vlan).order_by(nullsfirst(models.Vlan.vlan_id)).all()
|
||||
|
||||
|
||||
@router.post("/", response_model=VlanOut)
|
||||
def create_vlan(vlan: VlanCreate, db: Session = Depends(get_db)):
|
||||
if vlan.vlan_id is not None:
|
||||
existing = db.query(models.Vlan).filter(models.Vlan.vlan_id == vlan.vlan_id).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail=f"VLAN {vlan.vlan_id} existe déjà")
|
||||
db_vlan = models.Vlan(**vlan.model_dump())
|
||||
db.add(db_vlan)
|
||||
db.commit()
|
||||
db.refresh(db_vlan)
|
||||
return db_vlan
|
||||
|
||||
|
||||
@router.put("/{vlan_pk}", response_model=VlanOut)
|
||||
def update_vlan(vlan_pk: int, vlan: VlanCreate, db: Session = Depends(get_db)):
|
||||
db_vlan = db.query(models.Vlan).filter(models.Vlan.id == vlan_pk).first()
|
||||
if not db_vlan:
|
||||
raise HTTPException(status_code=404, detail="VLAN introuvable")
|
||||
for k, v in vlan.model_dump().items():
|
||||
setattr(db_vlan, k, v)
|
||||
db.commit()
|
||||
db.refresh(db_vlan)
|
||||
return db_vlan
|
||||
|
||||
|
||||
@router.delete("/{vlan_pk}")
|
||||
def delete_vlan(vlan_pk: int, db: Session = Depends(get_db)):
|
||||
db_vlan = db.query(models.Vlan).filter(models.Vlan.id == vlan_pk).first()
|
||||
if not db_vlan:
|
||||
raise HTTPException(status_code=404, detail="VLAN introuvable")
|
||||
db.delete(db_vlan)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
Reference in New Issue
Block a user