Initial commit — Stupid Simple Network Inventory
Application web d'inventaire réseau manuel avec FastAPI, Vue 3 et Docker. Inclut l'authentification JWT, la découverte ICMP, et la topologie en cards CSS. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
FROM python:3.11-slim
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends iputils-ping \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
RUN mkdir -p data
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -0,0 +1,35 @@
|
||||
import sqlite3
|
||||
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
import os
|
||||
|
||||
os.makedirs("data", exist_ok=True)
|
||||
|
||||
_DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:///./data/topology.db")
|
||||
|
||||
engine = create_engine(
|
||||
_DATABASE_URL,
|
||||
connect_args={"check_same_thread": False}
|
||||
)
|
||||
|
||||
|
||||
@event.listens_for(Engine, "connect")
|
||||
def _set_sqlite_pragma(dbapi_conn, _record):
|
||||
if isinstance(dbapi_conn, sqlite3.Connection):
|
||||
cursor = dbapi_conn.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
+201
@@ -0,0 +1,201 @@
|
||||
import os
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy import text
|
||||
from database import engine, Base
|
||||
from routers import vlans, devices, discovery
|
||||
from routers.auth import router as auth_router, get_current_user, require_password_changed
|
||||
|
||||
|
||||
def _migrate_vlan_nullable():
|
||||
"""Make vlans.vlan_id nullable (SQLite can't ALTER COLUMN, so recreate)."""
|
||||
with engine.connect() as conn:
|
||||
if not conn.execute(text(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='vlans'"
|
||||
)).fetchone():
|
||||
return
|
||||
cols = conn.execute(text("PRAGMA table_info(vlans)")).fetchall()
|
||||
if not any(row[1] == 'vlan_id' and row[3] == 1 for row in cols):
|
||||
return
|
||||
conn.execute(text("PRAGMA foreign_keys=OFF"))
|
||||
conn.execute(text("""
|
||||
CREATE TABLE vlans_new (
|
||||
id INTEGER NOT NULL PRIMARY KEY,
|
||||
vlan_id INTEGER UNIQUE,
|
||||
name VARCHAR NOT NULL,
|
||||
cidr VARCHAR,
|
||||
color VARCHAR
|
||||
)
|
||||
"""))
|
||||
conn.execute(text("INSERT INTO vlans_new SELECT id, vlan_id, name, cidr, color FROM vlans"))
|
||||
conn.execute(text("DROP TABLE vlans"))
|
||||
conn.execute(text("ALTER TABLE vlans_new RENAME TO vlans"))
|
||||
conn.commit()
|
||||
conn.execute(text("PRAGMA foreign_keys=ON"))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _migrate_device_virt_type():
|
||||
"""Ajoute la colonne virt_type sur devices si absente."""
|
||||
with engine.connect() as conn:
|
||||
if not conn.execute(text(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='devices'"
|
||||
)).fetchone():
|
||||
return
|
||||
cols = [row[1] for row in conn.execute(text("PRAGMA table_info(devices)")).fetchall()]
|
||||
if 'virt_type' not in cols:
|
||||
conn.execute(text("ALTER TABLE devices ADD COLUMN virt_type VARCHAR"))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _migrate_device_url():
|
||||
"""Ajoute la colonne url sur devices si absente."""
|
||||
with engine.connect() as conn:
|
||||
if not conn.execute(text(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='devices'"
|
||||
)).fetchone():
|
||||
return
|
||||
cols = [row[1] for row in conn.execute(text("PRAGMA table_info(devices)")).fetchall()]
|
||||
if 'url' not in cols:
|
||||
conn.execute(text("ALTER TABLE devices ADD COLUMN url VARCHAR"))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _migrate_users_must_change_password():
|
||||
"""Ajoute la colonne must_change_password sur users si absente."""
|
||||
with engine.connect() as conn:
|
||||
if not conn.execute(text(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
|
||||
)).fetchone():
|
||||
return
|
||||
cols = [row[1] for row in conn.execute(text("PRAGMA table_info(users)")).fetchall()]
|
||||
if 'must_change_password' not in cols:
|
||||
conn.execute(text(
|
||||
"ALTER TABLE users ADD COLUMN must_change_password BOOLEAN NOT NULL DEFAULT 0"
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _migrate_users_token_version():
|
||||
"""Ajoute la colonne token_version sur users si absente."""
|
||||
with engine.connect() as conn:
|
||||
if not conn.execute(text(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
|
||||
)).fetchone():
|
||||
return
|
||||
cols = [row[1] for row in conn.execute(text("PRAGMA table_info(users)")).fetchall()]
|
||||
if 'token_version' not in cols:
|
||||
conn.execute(text(
|
||||
"ALTER TABLE users ADD COLUMN token_version INTEGER NOT NULL DEFAULT 1"
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _migrate_force_admin_password_change():
|
||||
"""Force must_change_password=1 pour admin utilisant encore le mot de passe bootstrap."""
|
||||
from passlib.context import CryptContext
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
with engine.connect() as conn:
|
||||
if not conn.execute(text(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
|
||||
)).fetchone():
|
||||
return
|
||||
row = conn.execute(text(
|
||||
"SELECT hashed_password FROM users WHERE username='admin' AND must_change_password=0"
|
||||
)).fetchone()
|
||||
if row and pwd_context.verify("admin", row[0]):
|
||||
conn.execute(text(
|
||||
"UPDATE users SET must_change_password=1 WHERE username='admin'"
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _migrate_users():
|
||||
"""Crée la table users et le compte admin par défaut si absents."""
|
||||
from passlib.context import CryptContext
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
initial_password = os.environ.get("INITIAL_ADMIN_PASSWORD", "")
|
||||
with engine.connect() as conn:
|
||||
table_exists = conn.execute(text(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
|
||||
)).fetchone()
|
||||
if not table_exists:
|
||||
conn.execute(text("""
|
||||
CREATE TABLE users (
|
||||
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
username VARCHAR NOT NULL UNIQUE,
|
||||
hashed_password VARCHAR NOT NULL,
|
||||
must_change_password BOOLEAN NOT NULL DEFAULT 0,
|
||||
token_version INTEGER NOT NULL DEFAULT 1
|
||||
)
|
||||
"""))
|
||||
conn.commit()
|
||||
count = conn.execute(text("SELECT COUNT(*) FROM users")).fetchone()[0]
|
||||
if count == 0:
|
||||
if initial_password:
|
||||
hashed = pwd_context.hash(initial_password)
|
||||
must_change = 0
|
||||
else:
|
||||
hashed = pwd_context.hash("admin")
|
||||
must_change = 1
|
||||
conn.execute(
|
||||
text("INSERT INTO users (username, hashed_password, must_change_password) VALUES ('admin', :h, :m)"),
|
||||
{"h": hashed, "m": must_change},
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _migrate_drop_links_table():
|
||||
"""Supprime la table links (fonctionnalité retirée en phase 3). Idempotent."""
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("PRAGMA foreign_keys=OFF"))
|
||||
if conn.execute(text(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='links'"
|
||||
)).fetchone():
|
||||
conn.execute(text("DROP TABLE links"))
|
||||
conn.commit()
|
||||
conn.execute(text("PRAGMA foreign_keys=ON"))
|
||||
conn.commit()
|
||||
|
||||
|
||||
_migrate_vlan_nullable()
|
||||
_migrate_device_virt_type()
|
||||
_migrate_device_url()
|
||||
_migrate_users_must_change_password()
|
||||
_migrate_users_token_version()
|
||||
_migrate_force_admin_password_change()
|
||||
_migrate_drop_links_table()
|
||||
_migrate_users()
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
app = FastAPI(title="Network Topology Manager")
|
||||
|
||||
# CORS — configurable via ALLOWED_ORIGINS env var (comma-separated).
|
||||
# Default "*" for backward compatibility in a behind-proxy deployment.
|
||||
# Production: set ALLOWED_ORIGINS="" to disable, or "https://yourdomain.com".
|
||||
_allowed_origins_env = os.environ.get("ALLOWED_ORIGINS", "*")
|
||||
if _allowed_origins_env.strip() == "*":
|
||||
_origins = ["*"]
|
||||
elif _allowed_origins_env.strip() == "":
|
||||
_origins = []
|
||||
else:
|
||||
_origins = [o.strip() for o in _allowed_origins_env.split(",") if o.strip()]
|
||||
|
||||
if _origins:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=_origins,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
|
||||
app.include_router(vlans.router, prefix="/api/vlans", tags=["vlans"], dependencies=[Depends(require_password_changed)])
|
||||
app.include_router(devices.router, prefix="/api/devices", tags=["devices"], dependencies=[Depends(require_password_changed)])
|
||||
app.include_router(discovery.router, prefix="/api/discovery", tags=["discovery"], dependencies=[Depends(require_password_changed)])
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
def health():
|
||||
return {"status": "ok"}
|
||||
@@ -0,0 +1,58 @@
|
||||
from sqlalchemy import Column, Integer, String, Boolean, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from database import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
username = Column(String, unique=True, nullable=False)
|
||||
hashed_password = Column(String, nullable=False)
|
||||
must_change_password = Column(Boolean, nullable=False, default=False, server_default="0")
|
||||
token_version = Column(Integer, nullable=False, default=1, server_default="1")
|
||||
|
||||
|
||||
class Vlan(Base):
|
||||
__tablename__ = "vlans"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
vlan_id = Column(Integer, unique=True, nullable=True)
|
||||
name = Column(String, nullable=False)
|
||||
cidr = Column(String, nullable=True, default="")
|
||||
color = Column(String, default="#4A90D9")
|
||||
|
||||
interfaces = relationship("DeviceInterface", back_populates="vlan")
|
||||
|
||||
|
||||
class Device(Base):
|
||||
__tablename__ = "devices"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
type = Column(String, default="other")
|
||||
description = Column(String, default="")
|
||||
is_gateway = Column(Boolean, default=False)
|
||||
is_livebox = Column(Boolean, default=False)
|
||||
virt_type = Column(String, nullable=True)
|
||||
url = Column(String, nullable=True)
|
||||
|
||||
interfaces = relationship(
|
||||
"DeviceInterface", back_populates="device", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class DeviceInterface(Base):
|
||||
__tablename__ = "device_interfaces"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
device_id = Column(Integer, ForeignKey("devices.id"), nullable=False)
|
||||
vlan_id = Column(Integer, ForeignKey("vlans.id"), nullable=True)
|
||||
ip_address = Column(String, nullable=True, default="")
|
||||
name = Column(String, default="eth0")
|
||||
is_upstream = Column(Boolean, default=False)
|
||||
|
||||
device = relationship("Device", back_populates="interfaces")
|
||||
vlan = relationship("Vlan", back_populates="interfaces")
|
||||
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
pytest>=7.4
|
||||
httpx>=0.25
|
||||
@@ -0,0 +1,9 @@
|
||||
fastapi==0.104.1
|
||||
uvicorn[standard]==0.24.0
|
||||
sqlalchemy==2.0.23
|
||||
pydantic==2.5.0
|
||||
python-multipart==0.0.6
|
||||
dnspython==2.4.2
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
bcrypt==3.2.2
|
||||
@@ -0,0 +1,218 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import time
|
||||
import threading
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_SECRET_KEY_FILE = "data/secret_key.txt"
|
||||
_audit = logging.getLogger("audit")
|
||||
|
||||
|
||||
def _log_audit(event: str, **kw) -> None:
|
||||
_audit.info(json.dumps({"event": event, "ts": datetime.now(timezone.utc).isoformat(), **kw}))
|
||||
|
||||
|
||||
def _load_secret_key() -> str:
|
||||
env = os.environ.get("SECRET_KEY")
|
||||
if env:
|
||||
return env
|
||||
if os.path.exists(_SECRET_KEY_FILE):
|
||||
return open(_SECRET_KEY_FILE).read().strip()
|
||||
key = secrets.token_hex(32)
|
||||
os.makedirs(os.path.dirname(_SECRET_KEY_FILE), exist_ok=True)
|
||||
# Create with owner-only permissions (0600) to prevent other users from reading the key
|
||||
fd = os.open(_SECRET_KEY_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||
try:
|
||||
with os.fdopen(fd, "w") as f:
|
||||
f.write(key)
|
||||
except Exception:
|
||||
os.close(fd)
|
||||
raise
|
||||
return key
|
||||
|
||||
|
||||
SECRET_KEY = _load_secret_key()
|
||||
ALGORITHM = "HS256"
|
||||
TOKEN_EXPIRE_HOURS = 24
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
|
||||
|
||||
# --- Rate limiting ---
|
||||
_login_attempts: dict[str, list[float]] = {} # username → timestamps
|
||||
_ip_attempts: dict[str, list[float]] = {} # ip → timestamps
|
||||
_rate_lock = threading.Lock()
|
||||
_USERNAME_WINDOW = 900 # 15 min
|
||||
_USERNAME_MAX = 10
|
||||
_IP_WINDOW = 60 # 1 min
|
||||
_IP_MAX = 20
|
||||
|
||||
|
||||
def _check_username_rate_limit(username: str) -> None:
|
||||
now = time.time()
|
||||
with _rate_lock:
|
||||
attempts = [t for t in _login_attempts.get(username, []) if now - t < _USERNAME_WINDOW]
|
||||
if len(attempts) >= _USERNAME_MAX:
|
||||
raise HTTPException(status_code=429, detail="Too many attempts, try again later")
|
||||
attempts.append(now)
|
||||
_login_attempts[username] = attempts
|
||||
|
||||
|
||||
def _check_ip_rate_limit(ip: str) -> None:
|
||||
now = time.time()
|
||||
with _rate_lock:
|
||||
attempts = [t for t in _ip_attempts.get(ip, []) if now - t < _IP_WINDOW]
|
||||
if len(attempts) >= _IP_MAX:
|
||||
raise HTTPException(status_code=429, detail="Too many attempts, try again later")
|
||||
attempts.append(now)
|
||||
_ip_attempts[ip] = attempts
|
||||
|
||||
|
||||
def _clear_login_attempts(username: str) -> None:
|
||||
with _rate_lock:
|
||||
_login_attempts.pop(username, None)
|
||||
|
||||
|
||||
def create_token(username: str, version: int) -> str:
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=TOKEN_EXPIRE_HOURS)
|
||||
return jwt.encode(
|
||||
{"sub": username, "ver": version, "exp": expire},
|
||||
SECRET_KEY,
|
||||
algorithm=ALGORITHM,
|
||||
)
|
||||
|
||||
|
||||
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> User:
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
token_ver: int = payload.get("ver", 1)
|
||||
if not username:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=401, detail="Invalid token", headers={"WWW-Authenticate": "Bearer"})
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
_log_audit("auth.token_rejected", username=username, reason="user_not_found")
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
if (user.token_version or 1) != token_ver:
|
||||
_log_audit("auth.token_rejected", username=username, reason="version_mismatch")
|
||||
raise HTTPException(status_code=401, detail="Session expired, please log in again")
|
||||
return user
|
||||
|
||||
|
||||
def require_password_changed(current_user: User = Depends(get_current_user)) -> User:
|
||||
if current_user.must_change_password:
|
||||
raise HTTPException(status_code=403, detail="Password change required")
|
||||
return current_user
|
||||
|
||||
|
||||
# --- Validation helpers ---
|
||||
_USERNAME_RE = re.compile(r"^[a-zA-Z0-9._-]{1,64}$")
|
||||
|
||||
|
||||
def _validate_new_password(password: str) -> None:
|
||||
if len(password) < 8:
|
||||
raise HTTPException(status_code=400, detail="password_too_short")
|
||||
if not re.search(r"[a-zA-Z]", password) or not re.search(r"[0-9]", password):
|
||||
raise HTTPException(status_code=400, detail="password_too_weak")
|
||||
|
||||
|
||||
def _validate_new_username(username: str) -> None:
|
||||
if not _USERNAME_RE.match(username):
|
||||
raise HTTPException(status_code=400, detail="username_invalid")
|
||||
|
||||
|
||||
class TokenOut(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
username: str
|
||||
must_change_password: bool = False
|
||||
|
||||
|
||||
class AccountUpdate(BaseModel):
|
||||
current_password: str
|
||||
new_username: str | None = None
|
||||
new_password: str | None = None
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenOut)
|
||||
def login(request: Request, form: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
try:
|
||||
_check_ip_rate_limit(client_ip)
|
||||
except HTTPException:
|
||||
_log_audit("auth.login.rate_limited", ip=client_ip, reason="ip")
|
||||
raise
|
||||
try:
|
||||
_check_username_rate_limit(form.username)
|
||||
except HTTPException:
|
||||
_log_audit("auth.login.rate_limited", ip=client_ip, username=form.username, reason="username")
|
||||
raise
|
||||
user = db.query(User).filter(User.username == form.username).first()
|
||||
if not user or not pwd_context.verify(form.password, user.hashed_password):
|
||||
_log_audit("auth.login.failure", username=form.username, ip=client_ip)
|
||||
raise HTTPException(status_code=401, detail="Incorrect username or password")
|
||||
_clear_login_attempts(form.username)
|
||||
_log_audit("auth.login.success", username=user.username, ip=client_ip)
|
||||
return {
|
||||
"access_token": create_token(user.username, user.token_version or 1),
|
||||
"token_type": "bearer",
|
||||
"username": user.username,
|
||||
"must_change_password": bool(user.must_change_password),
|
||||
}
|
||||
|
||||
|
||||
@router.put("/account", response_model=TokenOut)
|
||||
def update_account(
|
||||
data: AccountUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
if not pwd_context.verify(data.current_password, current_user.hashed_password):
|
||||
_log_audit("auth.account.bad_password", username=current_user.username)
|
||||
raise HTTPException(status_code=400, detail="Current password is incorrect")
|
||||
if data.new_username and data.new_username != current_user.username:
|
||||
_validate_new_username(data.new_username)
|
||||
if db.query(User).filter(User.username == data.new_username).first():
|
||||
raise HTTPException(status_code=400, detail="Username already taken")
|
||||
old_username = current_user.username
|
||||
current_user.username = data.new_username
|
||||
_log_audit("auth.account.username_changed", old_username=old_username, new_username=data.new_username)
|
||||
if data.new_password:
|
||||
_validate_new_password(data.new_password)
|
||||
current_user.hashed_password = pwd_context.hash(data.new_password)
|
||||
current_user.must_change_password = False
|
||||
# Invalidate all previously issued tokens by bumping the version
|
||||
current_user.token_version = (current_user.token_version or 1) + 1
|
||||
_log_audit("auth.account.password_changed", username=current_user.username)
|
||||
db.commit()
|
||||
return {
|
||||
"access_token": create_token(current_user.username, current_user.token_version or 1),
|
||||
"token_type": "bearer",
|
||||
"username": current_user.username,
|
||||
"must_change_password": bool(current_user.must_change_password),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
def get_me(current_user: User = Depends(get_current_user)):
|
||||
return {
|
||||
"username": current_user.username,
|
||||
"must_change_password": bool(current_user.must_change_password),
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
import ipaddress
|
||||
import re
|
||||
from typing import Optional, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
import models
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_VALID_TYPES = {
|
||||
"server", "switch", "router", "nas", "gateway", "livebox", "access_point",
|
||||
"camera", "temperature", "sensor", "hub", "smart_plug", "alarm", "light",
|
||||
"doorbell", "desktop", "laptop", "other",
|
||||
}
|
||||
_VALID_VIRT_TYPES = {None, "baremetal", "lxc", "qemu"}
|
||||
|
||||
|
||||
class InterfaceCreate(BaseModel):
|
||||
name: str = "eth0"
|
||||
ip_address: Optional[str] = ""
|
||||
vlan_id: Optional[int] = None
|
||||
is_upstream: bool = False
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def _name(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("Interface name cannot be empty")
|
||||
if len(v) > 50:
|
||||
raise ValueError("Interface name too long (max 50 characters)")
|
||||
return v
|
||||
|
||||
@field_validator("ip_address")
|
||||
@classmethod
|
||||
def _ip(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v:
|
||||
try:
|
||||
ipaddress.ip_address(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid IP address: {v!r}")
|
||||
return v
|
||||
|
||||
|
||||
class InterfaceOut(InterfaceCreate):
|
||||
id: int
|
||||
device_id: int
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DeviceCreate(BaseModel):
|
||||
name: str
|
||||
type: str = "other"
|
||||
description: str = ""
|
||||
is_gateway: bool = False
|
||||
is_livebox: bool = False
|
||||
virt_type: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
interfaces: List[InterfaceCreate] = []
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def _name(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("name cannot be empty")
|
||||
if len(v) > 100:
|
||||
raise ValueError("name too long (max 100 characters)")
|
||||
return v
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def _description(cls, v: str) -> str:
|
||||
if len(v) > 500:
|
||||
raise ValueError("description too long (max 500 characters)")
|
||||
return v
|
||||
|
||||
@field_validator("type")
|
||||
@classmethod
|
||||
def _type(cls, v: str) -> str:
|
||||
if v not in _VALID_TYPES:
|
||||
raise ValueError(f"Invalid type: {v!r}. Must be one of: {sorted(_VALID_TYPES)}")
|
||||
return v
|
||||
|
||||
@field_validator("virt_type")
|
||||
@classmethod
|
||||
def _virt_type(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v not in _VALID_VIRT_TYPES:
|
||||
raise ValueError(f"Invalid virt_type: {v!r}. Must be one of: baremetal, lxc, qemu")
|
||||
return v
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def _url(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v:
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https") or not parsed.netloc:
|
||||
raise ValueError("url must be a valid http or https URL")
|
||||
return v
|
||||
|
||||
|
||||
class DeviceOut(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
is_gateway: bool
|
||||
is_livebox: bool
|
||||
virt_type: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
interfaces: List[InterfaceOut] = []
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
@router.get("/", response_model=List[DeviceOut])
|
||||
def list_devices(db: Session = Depends(get_db)):
|
||||
return db.query(models.Device).order_by(models.Device.name).all()
|
||||
|
||||
|
||||
@router.post("/", response_model=DeviceOut)
|
||||
def create_device(device: DeviceCreate, db: Session = Depends(get_db)):
|
||||
db_device = models.Device(
|
||||
name=device.name,
|
||||
type=device.type,
|
||||
description=device.description,
|
||||
is_gateway=device.is_gateway,
|
||||
is_livebox=device.is_livebox,
|
||||
virt_type=device.virt_type,
|
||||
url=device.url,
|
||||
)
|
||||
db.add(db_device)
|
||||
db.flush()
|
||||
for iface in device.interfaces:
|
||||
db.add(models.DeviceInterface(device_id=db_device.id, **iface.model_dump()))
|
||||
db.commit()
|
||||
db.refresh(db_device)
|
||||
return db_device
|
||||
|
||||
|
||||
@router.put("/{device_id}", response_model=DeviceOut)
|
||||
def update_device(device_id: int, device: DeviceCreate, db: Session = Depends(get_db)):
|
||||
db_device = db.query(models.Device).filter(models.Device.id == device_id).first()
|
||||
if not db_device:
|
||||
raise HTTPException(status_code=404, detail="Équipement introuvable")
|
||||
db_device.name = device.name
|
||||
db_device.type = device.type
|
||||
db_device.description = device.description
|
||||
db_device.is_gateway = device.is_gateway
|
||||
db_device.is_livebox = device.is_livebox
|
||||
db_device.virt_type = device.virt_type
|
||||
db_device.url = device.url
|
||||
db.query(models.DeviceInterface).filter(
|
||||
models.DeviceInterface.device_id == device_id
|
||||
).delete()
|
||||
for iface in device.interfaces:
|
||||
db.add(models.DeviceInterface(device_id=device_id, **iface.model_dump()))
|
||||
db.commit()
|
||||
db.refresh(db_device)
|
||||
return db_device
|
||||
|
||||
|
||||
@router.delete("/{device_id}")
|
||||
def delete_device(device_id: int, db: Session = Depends(get_db)):
|
||||
db_device = db.query(models.Device).filter(models.Device.id == device_id).first()
|
||||
if not db_device:
|
||||
raise HTTPException(status_code=404, detail="Équipement introuvable")
|
||||
db.delete(db_device)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
@@ -0,0 +1,156 @@
|
||||
import ipaddress
|
||||
import subprocess
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Optional
|
||||
|
||||
import dns.resolver
|
||||
import dns.reversename
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
MAX_HOSTS_PER_TARGET = 1024 # refuse les /21 et plus larges
|
||||
MAX_HOSTS_TOTAL = 4096 # cap global sur l'ensemble des targets
|
||||
|
||||
|
||||
class ScanTarget(BaseModel):
|
||||
vlan_id: int
|
||||
cidr: str
|
||||
|
||||
|
||||
class ScanRequest(BaseModel):
|
||||
dns_server: str = "8.8.8.8"
|
||||
targets: list[ScanTarget]
|
||||
|
||||
@field_validator("dns_server")
|
||||
@classmethod
|
||||
def _dns_server(cls, v: str) -> str:
|
||||
try:
|
||||
ipaddress.ip_address(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"dns_server must be a valid IP address, got: {v!r}")
|
||||
return v
|
||||
|
||||
|
||||
class DiscoveredHost(BaseModel):
|
||||
ip: str
|
||||
hostname: Optional[str] = None
|
||||
vlan_id: int
|
||||
cidr: str
|
||||
|
||||
|
||||
class ScanResponse(BaseModel):
|
||||
hosts: list[DiscoveredHost]
|
||||
total_scanned: int
|
||||
duration_s: float
|
||||
|
||||
|
||||
def _ping(ip: str) -> bool:
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["ping", "-c", "1", "-W", "1", ip],
|
||||
capture_output=True,
|
||||
timeout=3,
|
||||
)
|
||||
return r.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _ptr_lookup(ip: str, nameserver: str) -> Optional[str]:
|
||||
try:
|
||||
resolver = dns.resolver.Resolver(configure=False)
|
||||
resolver.nameservers = [nameserver]
|
||||
resolver.timeout = 1
|
||||
resolver.lifetime = 2
|
||||
rev = dns.reversename.from_address(ip)
|
||||
ans = resolver.resolve(rev, "PTR")
|
||||
return str(ans[0]).rstrip(".")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _scan_one(ip: str, dns_server: str, vlan_id: int, cidr: str) -> Optional[DiscoveredHost]:
|
||||
if not _ping(ip):
|
||||
return None
|
||||
hostname = _ptr_lookup(ip, dns_server)
|
||||
return DiscoveredHost(ip=ip, hostname=hostname, vlan_id=vlan_id, cidr=cidr)
|
||||
|
||||
|
||||
class PingRequest(BaseModel):
|
||||
ips: list[str]
|
||||
|
||||
@field_validator("ips")
|
||||
@classmethod
|
||||
def _ips(cls, v: list[str]) -> list[str]:
|
||||
for ip in v:
|
||||
try:
|
||||
ipaddress.ip_address(ip)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid IP address: {ip!r}")
|
||||
return v
|
||||
|
||||
|
||||
class PingResult(BaseModel):
|
||||
ip: str
|
||||
alive: bool
|
||||
|
||||
|
||||
@router.post("/ping", response_model=list[PingResult])
|
||||
def ping_many(req: PingRequest):
|
||||
if not req.ips:
|
||||
return []
|
||||
with ThreadPoolExecutor(max_workers=50) as pool:
|
||||
futures = {pool.submit(_ping, ip): ip for ip in req.ips}
|
||||
results = [PingResult(ip=futures[f], alive=f.result()) for f in as_completed(futures)]
|
||||
return results
|
||||
|
||||
|
||||
@router.post("/scan", response_model=ScanResponse)
|
||||
def scan(req: ScanRequest):
|
||||
tasks: list[tuple[str, str, int, str]] = []
|
||||
|
||||
for t in req.targets:
|
||||
try:
|
||||
net = ipaddress.ip_network(t.cidr, strict=False)
|
||||
except ValueError:
|
||||
raise HTTPException(400, f"CIDR invalide : {t.cidr}")
|
||||
|
||||
hosts = list(net.hosts())
|
||||
if len(hosts) > MAX_HOSTS_PER_TARGET:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"Réseau {t.cidr} trop large ({len(hosts)} hôtes). "
|
||||
f"Maximum par target : {MAX_HOSTS_PER_TARGET} hôtes (/22 ou plus petit).",
|
||||
)
|
||||
for ip in hosts:
|
||||
tasks.append((str(ip), req.dns_server, t.vlan_id, t.cidr))
|
||||
|
||||
if not tasks:
|
||||
raise HTTPException(400, "Aucune cible à scanner.")
|
||||
|
||||
if len(tasks) > MAX_HOSTS_TOTAL:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"Trop d'hôtes au total ({len(tasks)}). Maximum global : {MAX_HOSTS_TOTAL}.",
|
||||
)
|
||||
|
||||
t0 = time.time()
|
||||
results: list[DiscoveredHost] = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=100) as pool:
|
||||
futures = [pool.submit(_scan_one, *args) for args in tasks]
|
||||
for f in as_completed(futures):
|
||||
host = f.result()
|
||||
if host:
|
||||
results.append(host)
|
||||
|
||||
results.sort(key=lambda h: ipaddress.ip_address(h.ip))
|
||||
|
||||
return ScanResponse(
|
||||
hosts=results,
|
||||
total_scanned=len(tasks),
|
||||
duration_s=round(time.time() - t0, 1),
|
||||
)
|
||||
@@ -0,0 +1,102 @@
|
||||
import ipaddress
|
||||
import re
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, field_validator
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
import models
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_COLOR_RE = re.compile(r"^#[0-9a-fA-F]{6}$")
|
||||
|
||||
|
||||
class VlanCreate(BaseModel):
|
||||
vlan_id: Optional[int] = None
|
||||
name: str
|
||||
cidr: Optional[str] = ""
|
||||
color: str = "#4A90D9"
|
||||
|
||||
@field_validator("vlan_id")
|
||||
@classmethod
|
||||
def _vlan_id(cls, v: Optional[int]) -> Optional[int]:
|
||||
if v is not None and not (1 <= v <= 4094):
|
||||
raise ValueError("vlan_id must be between 1 and 4094")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def _name(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("name cannot be empty")
|
||||
if len(v) > 100:
|
||||
raise ValueError("name too long (max 100 characters)")
|
||||
return v
|
||||
|
||||
@field_validator("cidr")
|
||||
@classmethod
|
||||
def _cidr(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v:
|
||||
try:
|
||||
ipaddress.ip_network(v, strict=False)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid CIDR notation: {v!r}")
|
||||
return v
|
||||
|
||||
@field_validator("color")
|
||||
@classmethod
|
||||
def _color(cls, v: str) -> str:
|
||||
if not _COLOR_RE.match(v):
|
||||
raise ValueError("color must be a 6-digit hex color (e.g. #4A90D9)")
|
||||
return v
|
||||
|
||||
|
||||
class VlanOut(VlanCreate):
|
||||
id: int
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
@router.get("/", response_model=List[VlanOut])
|
||||
def list_vlans(db: Session = Depends(get_db)):
|
||||
return db.query(models.Vlan).order_by(nullsfirst(models.Vlan.vlan_id)).all()
|
||||
|
||||
|
||||
@router.post("/", response_model=VlanOut)
|
||||
def create_vlan(vlan: VlanCreate, db: Session = Depends(get_db)):
|
||||
if vlan.vlan_id is not None:
|
||||
existing = db.query(models.Vlan).filter(models.Vlan.vlan_id == vlan.vlan_id).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail=f"VLAN {vlan.vlan_id} existe déjà")
|
||||
db_vlan = models.Vlan(**vlan.model_dump())
|
||||
db.add(db_vlan)
|
||||
db.commit()
|
||||
db.refresh(db_vlan)
|
||||
return db_vlan
|
||||
|
||||
|
||||
@router.put("/{vlan_pk}", response_model=VlanOut)
|
||||
def update_vlan(vlan_pk: int, vlan: VlanCreate, db: Session = Depends(get_db)):
|
||||
db_vlan = db.query(models.Vlan).filter(models.Vlan.id == vlan_pk).first()
|
||||
if not db_vlan:
|
||||
raise HTTPException(status_code=404, detail="VLAN introuvable")
|
||||
for k, v in vlan.model_dump().items():
|
||||
setattr(db_vlan, k, v)
|
||||
db.commit()
|
||||
db.refresh(db_vlan)
|
||||
return db_vlan
|
||||
|
||||
|
||||
@router.delete("/{vlan_pk}")
|
||||
def delete_vlan(vlan_pk: int, db: Session = Depends(get_db)):
|
||||
db_vlan = db.query(models.Vlan).filter(models.Vlan.id == vlan_pk).first()
|
||||
if not db_vlan:
|
||||
raise HTTPException(status_code=404, detail="VLAN introuvable")
|
||||
db.delete(db_vlan)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Configures a fresh in-memory SQLite database for every test session.
|
||||
DATABASE_URL must be set before any app module is imported.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
# Must be set before importing database or main
|
||||
_tmpdb = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
_tmpdb.close()
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{_tmpdb.name}"
|
||||
os.environ.setdefault("SECRET_KEY", "test-only-secret-key-not-for-production")
|
||||
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Tests de sécurité pour l'authentification.
|
||||
|
||||
Couvre :
|
||||
- SEC-FIX-001 : bootstrap admin, rattrapage admin existant, blocage CRUD avant changement
|
||||
- SEC-FIX-002 : rate limiting login
|
||||
- SEC-FIX-003 : validation mot de passe et username
|
||||
- SEC-FIX-004 : invalidation de token après changement de mot de passe
|
||||
"""
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import text
|
||||
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# conftest.py sets DATABASE_URL before this import
|
||||
from database import engine, Base
|
||||
from main import (
|
||||
app,
|
||||
_migrate_users_must_change_password,
|
||||
_migrate_users_token_version,
|
||||
_migrate_force_admin_password_change,
|
||||
_migrate_users,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_db():
|
||||
"""Fresh schema + seeded admin for each test."""
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_migrate_users_must_change_password()
|
||||
_migrate_users_token_version()
|
||||
_migrate_users()
|
||||
yield
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_rate_limits():
|
||||
"""Remet à zéro les compteurs de rate limiting entre chaque test."""
|
||||
from routers.auth import _ip_attempts, _login_attempts, _rate_lock
|
||||
with _rate_lock:
|
||||
_ip_attempts.clear()
|
||||
_login_attempts.clear()
|
||||
yield
|
||||
with _rate_lock:
|
||||
_ip_attempts.clear()
|
||||
_login_attempts.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return TestClient(app, raise_server_exceptions=True)
|
||||
|
||||
|
||||
def _login(client, username="admin", password="admin"):
|
||||
return client.post(
|
||||
"/api/auth/login",
|
||||
data={"username": username, "password": password},
|
||||
)
|
||||
|
||||
|
||||
def _auth_headers(token):
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SEC-FIX-001 — Bootstrap et rattrapage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBootstrap:
|
||||
def test_fresh_db_admin_must_change_password(self, client):
|
||||
"""Nouvelle base : admin créé avec must_change_password=1."""
|
||||
r = _login(client)
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["must_change_password"] is True
|
||||
assert data["username"] == "admin"
|
||||
|
||||
def test_crud_blocked_before_password_change(self, client):
|
||||
"""CRUD refusé (403) tant que must_change_password est vrai."""
|
||||
token = _login(client).json()["access_token"]
|
||||
r = client.get("/api/vlans/", headers=_auth_headers(token))
|
||||
assert r.status_code == 403
|
||||
assert r.json()["detail"] == "Password change required"
|
||||
|
||||
def test_crud_allowed_after_password_change(self, client):
|
||||
"""CRUD autorisé après changement de mot de passe."""
|
||||
token = _login(client).json()["access_token"]
|
||||
r = client.put(
|
||||
"/api/auth/account",
|
||||
json={"current_password": "admin", "new_password": "SecurePass1"},
|
||||
headers=_auth_headers(token),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
new_token = r.json()["access_token"]
|
||||
assert r.json()["must_change_password"] is False
|
||||
r2 = client.get("/api/vlans/", headers=_auth_headers(new_token))
|
||||
assert r2.status_code == 200
|
||||
|
||||
def test_migration_forces_existing_admin_with_default_password(self, client):
|
||||
"""Rattrapage : admin existant avec must_change_password=0 et password 'admin' est forcé."""
|
||||
# Simuler une ancienne base : admin avec must_change_password=0
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("UPDATE users SET must_change_password=0 WHERE username='admin'"))
|
||||
conn.commit()
|
||||
# La migration de rattrapage doit remettre must_change_password=1
|
||||
_migrate_force_admin_password_change()
|
||||
r = _login(client)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["must_change_password"] is True
|
||||
|
||||
def test_migration_does_not_touch_admin_with_custom_password(self, client):
|
||||
"""Rattrapage : admin avec mot de passe personnalisé et must_change_password=0 n'est pas touché."""
|
||||
from passlib.context import CryptContext
|
||||
pwd = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(
|
||||
"UPDATE users SET hashed_password=:h, must_change_password=0 WHERE username='admin'"
|
||||
), {"h": pwd.hash("CustomPass9")})
|
||||
conn.commit()
|
||||
_migrate_force_admin_password_change()
|
||||
r = client.post("/api/auth/login", data={"username": "admin", "password": "CustomPass9"})
|
||||
assert r.status_code == 200
|
||||
assert r.json()["must_change_password"] is False
|
||||
|
||||
def test_initial_admin_password_env_var(self, monkeypatch):
|
||||
"""Avec INITIAL_ADMIN_PASSWORD, must_change_password=0."""
|
||||
monkeypatch.setenv("INITIAL_ADMIN_PASSWORD", "EnvPass42")
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_migrate_users_must_change_password()
|
||||
_migrate_users_token_version()
|
||||
_migrate_users()
|
||||
with TestClient(app) as c:
|
||||
r = c.post("/api/auth/login", data={"username": "admin", "password": "EnvPass42"})
|
||||
assert r.status_code == 200
|
||||
assert r.json()["must_change_password"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SEC-FIX-004 — Invalidation de token après changement de mot de passe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTokenInvalidation:
|
||||
def test_old_token_rejected_after_password_change(self, client):
|
||||
"""L'ancien token est invalide après changement de mot de passe."""
|
||||
# Forcer must_change_password=0 pour pouvoir tester le CRUD
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("UPDATE users SET must_change_password=0 WHERE username='admin'"))
|
||||
conn.commit()
|
||||
old_token = _login(client).json()["access_token"]
|
||||
# Changer le mot de passe → invalide old_token
|
||||
client.put(
|
||||
"/api/auth/account",
|
||||
json={"current_password": "admin", "new_password": "NewPass99"},
|
||||
headers=_auth_headers(old_token),
|
||||
)
|
||||
r = client.get("/api/vlans/", headers=_auth_headers(old_token))
|
||||
assert r.status_code == 401
|
||||
|
||||
def test_new_token_valid_after_password_change(self, client):
|
||||
"""Le nouveau token fonctionne après changement de mot de passe."""
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("UPDATE users SET must_change_password=0 WHERE username='admin'"))
|
||||
conn.commit()
|
||||
old_token = _login(client).json()["access_token"]
|
||||
r = client.put(
|
||||
"/api/auth/account",
|
||||
json={"current_password": "admin", "new_password": "NewPass99"},
|
||||
headers=_auth_headers(old_token),
|
||||
)
|
||||
new_token = r.json()["access_token"]
|
||||
r2 = client.get("/api/vlans/", headers=_auth_headers(new_token))
|
||||
assert r2.status_code == 200
|
||||
|
||||
def test_token_without_version_accepted_for_backward_compat(self, client):
|
||||
"""Token sans champ 'ver' (ancien format) est accepté : ver absent → ver=1 par défaut."""
|
||||
from jose import jwt as jose_jwt
|
||||
from routers.auth import SECRET_KEY, ALGORITHM
|
||||
from datetime import datetime, timedelta, timezone
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
old_format_token = jose_jwt.encode(
|
||||
{"sub": "admin", "exp": expire},
|
||||
SECRET_KEY,
|
||||
algorithm=ALGORITHM,
|
||||
)
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("UPDATE users SET must_change_password=0 WHERE username='admin'"))
|
||||
conn.commit()
|
||||
r = client.get("/api/auth/me", headers=_auth_headers(old_format_token))
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_token_with_wrong_version_rejected(self, client):
|
||||
"""Token avec version incorrecte est rejeté."""
|
||||
from jose import jwt as jose_jwt
|
||||
from routers.auth import SECRET_KEY, ALGORITHM
|
||||
from datetime import datetime, timedelta, timezone
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
bad_token = jose_jwt.encode(
|
||||
{"sub": "admin", "ver": 999, "exp": expire},
|
||||
SECRET_KEY,
|
||||
algorithm=ALGORITHM,
|
||||
)
|
||||
r = client.get("/api/auth/me", headers=_auth_headers(bad_token))
|
||||
assert r.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SEC-FIX-003 — Validation mot de passe et username
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidation:
|
||||
def _get_valid_token(self, client):
|
||||
"""Retourne un token valide (must_change_password forcé à 0)."""
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("UPDATE users SET must_change_password=0 WHERE username='admin'"))
|
||||
conn.commit()
|
||||
return _login(client).json()["access_token"]
|
||||
|
||||
def test_password_too_short_rejected(self, client):
|
||||
token = self._get_valid_token(client)
|
||||
r = client.put(
|
||||
"/api/auth/account",
|
||||
json={"current_password": "admin", "new_password": "Short1"},
|
||||
headers=_auth_headers(token),
|
||||
)
|
||||
assert r.status_code == 400
|
||||
assert r.json()["detail"] == "password_too_short"
|
||||
|
||||
def test_password_no_digit_rejected(self, client):
|
||||
token = self._get_valid_token(client)
|
||||
r = client.put(
|
||||
"/api/auth/account",
|
||||
json={"current_password": "admin", "new_password": "OnlyLetters"},
|
||||
headers=_auth_headers(token),
|
||||
)
|
||||
assert r.status_code == 400
|
||||
assert r.json()["detail"] == "password_too_weak"
|
||||
|
||||
def test_password_no_letter_rejected(self, client):
|
||||
token = self._get_valid_token(client)
|
||||
r = client.put(
|
||||
"/api/auth/account",
|
||||
json={"current_password": "admin", "new_password": "12345678"},
|
||||
headers=_auth_headers(token),
|
||||
)
|
||||
assert r.status_code == 400
|
||||
assert r.json()["detail"] == "password_too_weak"
|
||||
|
||||
def test_valid_password_accepted(self, client):
|
||||
token = self._get_valid_token(client)
|
||||
r = client.put(
|
||||
"/api/auth/account",
|
||||
json={"current_password": "admin", "new_password": "ValidPass1"},
|
||||
headers=_auth_headers(token),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_username_invalid_chars_rejected(self, client):
|
||||
token = self._get_valid_token(client)
|
||||
r = client.put(
|
||||
"/api/auth/account",
|
||||
json={"current_password": "admin", "new_username": "bad user!"},
|
||||
headers=_auth_headers(token),
|
||||
)
|
||||
assert r.status_code == 400
|
||||
assert r.json()["detail"] == "username_invalid"
|
||||
|
||||
def test_username_too_long_rejected(self, client):
|
||||
token = self._get_valid_token(client)
|
||||
r = client.put(
|
||||
"/api/auth/account",
|
||||
json={"current_password": "admin", "new_username": "a" * 65},
|
||||
headers=_auth_headers(token),
|
||||
)
|
||||
assert r.status_code == 400
|
||||
assert r.json()["detail"] == "username_invalid"
|
||||
|
||||
def test_valid_username_accepted(self, client):
|
||||
token = self._get_valid_token(client)
|
||||
r = client.put(
|
||||
"/api/auth/account",
|
||||
json={"current_password": "admin", "new_username": "admin_user.1"},
|
||||
headers=_auth_headers(token),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_wrong_current_password_rejected(self, client):
|
||||
token = self._get_valid_token(client)
|
||||
r = client.put(
|
||||
"/api/auth/account",
|
||||
json={"current_password": "wrong", "new_password": "NewPass1!"},
|
||||
headers=_auth_headers(token),
|
||||
)
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SEC-FIX-002 — Rate limiting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRateLimit:
|
||||
def test_ip_rate_limit_triggers_429(self, client):
|
||||
"""Après trop de tentatives par IP, le login retourne 429."""
|
||||
from routers.auth import _ip_attempts, _rate_lock, _IP_MAX
|
||||
with _rate_lock:
|
||||
_ip_attempts["testclient"] = [__import__("time").time()] * _IP_MAX
|
||||
r = _login(client)
|
||||
assert r.status_code == 429
|
||||
|
||||
def test_username_rate_limit_triggers_429(self, client):
|
||||
"""Après trop de tentatives par username, le login retourne 429."""
|
||||
from routers.auth import _login_attempts, _rate_lock, _USERNAME_MAX
|
||||
with _rate_lock:
|
||||
_login_attempts["admin"] = [__import__("time").time()] * _USERNAME_MAX
|
||||
r = _login(client)
|
||||
assert r.status_code == 429
|
||||
|
||||
def test_successful_login_clears_username_attempts(self, client):
|
||||
"""Login réussi remet à zéro le compteur username."""
|
||||
from routers.auth import _login_attempts, _rate_lock
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("UPDATE users SET must_change_password=0 WHERE username='admin'"))
|
||||
conn.commit()
|
||||
r = _login(client)
|
||||
assert r.status_code == 200
|
||||
with _rate_lock:
|
||||
assert "admin" not in _login_attempts
|
||||
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
Tests de validation — Phase 3
|
||||
|
||||
Couvre :
|
||||
- SEC-FIX-006 : validation des entrées discovery (dns_server, ips, cap global)
|
||||
- SEC-FIX-007 : validators Pydantic sur VlanCreate et DeviceCreate
|
||||
- SEC-FIX-013 : PRAGMA foreign_keys=ON (test indirect via FK constraint)
|
||||
- SEC-FIX-017 : suppression code orphelin Links (endpoint /api/links inexistant)
|
||||
"""
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import text
|
||||
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from database import engine, Base
|
||||
from main import app, _migrate_users_must_change_password, _migrate_users_token_version, _migrate_users
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_db():
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_migrate_users_must_change_password()
|
||||
_migrate_users_token_version()
|
||||
_migrate_users()
|
||||
yield
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_rate_limits():
|
||||
from routers.auth import _ip_attempts, _login_attempts, _rate_lock
|
||||
with _rate_lock:
|
||||
_ip_attempts.clear()
|
||||
_login_attempts.clear()
|
||||
yield
|
||||
with _rate_lock:
|
||||
_ip_attempts.clear()
|
||||
_login_attempts.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return TestClient(app, raise_server_exceptions=True)
|
||||
|
||||
|
||||
def _get_token(client):
|
||||
"""Retourne un token admin valide avec must_change_password=0."""
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("UPDATE users SET must_change_password=0 WHERE username='admin'"))
|
||||
conn.commit()
|
||||
r = client.post("/api/auth/login", data={"username": "admin", "password": "admin"})
|
||||
return r.json()["access_token"]
|
||||
|
||||
|
||||
def _auth(token):
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SEC-FIX-007 — Validation VlanCreate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestVlanValidation:
|
||||
def test_vlan_id_out_of_range_rejected(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/vlans/", json={"name": "Test", "vlan_id": 0}, headers=_auth(token))
|
||||
assert r.status_code == 422
|
||||
|
||||
def test_vlan_id_max_boundary_rejected(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/vlans/", json={"name": "Test", "vlan_id": 4095}, headers=_auth(token))
|
||||
assert r.status_code == 422
|
||||
|
||||
def test_vlan_id_valid_accepted(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/vlans/", json={"name": "Test", "vlan_id": 100, "color": "#AABBCC"}, headers=_auth(token))
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_vlan_empty_name_rejected(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/vlans/", json={"name": " "}, headers=_auth(token))
|
||||
assert r.status_code == 422
|
||||
|
||||
def test_vlan_invalid_cidr_rejected(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/vlans/", json={"name": "Test", "cidr": "not-a-cidr"}, headers=_auth(token))
|
||||
assert r.status_code == 422
|
||||
|
||||
def test_vlan_valid_cidr_accepted(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/vlans/", json={"name": "Test", "cidr": "192.168.1.0/24", "color": "#AABBCC"}, headers=_auth(token))
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_vlan_invalid_color_rejected(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/vlans/", json={"name": "Test", "color": "blue"}, headers=_auth(token))
|
||||
assert r.status_code == 422
|
||||
|
||||
def test_vlan_valid_color_accepted(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/vlans/", json={"name": "Test", "color": "#1a2B3c"}, headers=_auth(token))
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SEC-FIX-007 — Validation DeviceCreate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeviceValidation:
|
||||
def test_invalid_type_rejected(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/devices/", json={"name": "srv", "type": "supercomputer"}, headers=_auth(token))
|
||||
assert r.status_code == 422
|
||||
|
||||
def test_valid_type_accepted(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/devices/", json={"name": "srv", "type": "server"}, headers=_auth(token))
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_invalid_virt_type_rejected(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/devices/", json={"name": "srv", "virt_type": "docker"}, headers=_auth(token))
|
||||
assert r.status_code == 422
|
||||
|
||||
def test_valid_virt_type_accepted(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/devices/", json={"name": "srv", "virt_type": "lxc"}, headers=_auth(token))
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_invalid_url_rejected(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/devices/", json={"name": "srv", "url": "ftp://bad"}, headers=_auth(token))
|
||||
assert r.status_code == 422
|
||||
|
||||
def test_valid_url_accepted(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/devices/", json={"name": "srv", "url": "https://192.168.1.1:8006"}, headers=_auth(token))
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_empty_name_rejected(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/devices/", json={"name": ""}, headers=_auth(token))
|
||||
assert r.status_code == 422
|
||||
|
||||
def test_invalid_interface_ip_rejected(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/devices/", json={
|
||||
"name": "srv",
|
||||
"interfaces": [{"name": "eth0", "ip_address": "not-an-ip"}]
|
||||
}, headers=_auth(token))
|
||||
assert r.status_code == 422
|
||||
|
||||
def test_valid_interface_ip_accepted(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/devices/", json={
|
||||
"name": "srv",
|
||||
"interfaces": [{"name": "eth0", "ip_address": "10.0.0.1"}]
|
||||
}, headers=_auth(token))
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SEC-FIX-006 — Validation discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDiscoveryValidation:
|
||||
def test_invalid_dns_server_rejected(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/discovery/scan", json={
|
||||
"dns_server": "not-an-ip",
|
||||
"targets": [{"vlan_id": 1, "cidr": "192.168.1.0/24"}]
|
||||
}, headers=_auth(token))
|
||||
assert r.status_code == 422
|
||||
|
||||
def test_invalid_ping_ip_rejected(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/discovery/ping", json={"ips": ["1.2.3.4", "bad-ip"]}, headers=_auth(token))
|
||||
assert r.status_code == 422
|
||||
|
||||
def test_valid_ping_ips_accepted(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/discovery/ping", json={"ips": []}, headers=_auth(token))
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_scan_oversized_cidr_rejected(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/discovery/scan", json={
|
||||
"dns_server": "8.8.8.8",
|
||||
"targets": [{"vlan_id": 1, "cidr": "10.0.0.0/8"}]
|
||||
}, headers=_auth(token))
|
||||
assert r.status_code == 400
|
||||
|
||||
def test_scan_global_cap_rejected(self, client):
|
||||
"""Plusieurs targets dont le total dépasse MAX_HOSTS_TOTAL."""
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/discovery/scan", json={
|
||||
"dns_server": "8.8.8.8",
|
||||
"targets": [
|
||||
{"vlan_id": 1, "cidr": "10.0.0.0/22"}, # 1022 hôtes
|
||||
{"vlan_id": 2, "cidr": "10.1.0.0/22"}, # 1022 hôtes
|
||||
{"vlan_id": 3, "cidr": "10.2.0.0/22"}, # 1022 hôtes
|
||||
{"vlan_id": 4, "cidr": "10.3.0.0/22"}, # 1022 hôtes
|
||||
{"vlan_id": 5, "cidr": "10.4.0.0/22"}, # 1022 hôtes → total > 4096
|
||||
]
|
||||
}, headers=_auth(token))
|
||||
assert r.status_code == 400
|
||||
assert "total" in r.json()["detail"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SEC-FIX-017 — Endpoint /api/links absent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLinksEndpointRemoved:
|
||||
def test_links_list_returns_404(self, client):
|
||||
"""Le routeur /api/links a été supprimé en phase 3."""
|
||||
token = _get_token(client)
|
||||
r = client.get("/api/links/", headers=_auth(token))
|
||||
assert r.status_code == 404
|
||||
|
||||
def test_links_create_returns_404(self, client):
|
||||
token = _get_token(client)
|
||||
r = client.post("/api/links/", json={
|
||||
"source_device_id": 1, "target_device_id": 2, "link_type": "trunk"
|
||||
}, headers=_auth(token))
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SEC-FIX-013 — PRAGMA foreign_keys=ON
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestForeignKeys:
|
||||
def test_foreign_keys_pragma_is_on(self):
|
||||
"""Vérifie que PRAGMA foreign_keys=ON est actif sur chaque connexion."""
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(text("PRAGMA foreign_keys")).fetchone()
|
||||
assert result[0] == 1
|
||||
Reference in New Issue
Block a user