Files
stupid-simple-network-inven…/backend/routers/discovery.py
T
olivier e8ca10f1b7 fix: cap /api/discovery/ping at 4096 IPs and fix test suite
- Add MAX_PING_IPS=4096 constant and validate list size in PingRequest
  before spawning futures, returning 422 on overflow
- Add test_ping_too_many_ips_rejected to cover the new cap
- Pin httpx<0.28 in requirements-test.txt (0.28 broke TestClient API)
- Fix reset_db fixture to set a known admin password regardless of
  INITIAL_ADMIN_PASSWORD env var (was causing 401 on all auth tests)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-18 18:16:08 +02:00

214 lines
6.4 KiB
Python

import errno
import ipaddress
import os
import socket
import subprocess
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional
import dns.resolver
import dns.reversename
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, field_validator
# Router holding the discovery endpoints declared in this module.
router = APIRouter()
MAX_HOSTS_PER_TARGET = 1024  # per-target cap: rejects /21 and wider IPv4 networks
MAX_HOSTS_TOTAL = 4096  # global cap across all targets of one scan request
MAX_PING_IPS = 4096  # cap on the number of IPs accepted by /api/discovery/ping
# Default DNS server, read once at import time from the DNS_SERVER environment
# variable; empty string when the variable is unset or blank.
_ENV_DNS = os.environ.get("DNS_SERVER", "").strip()
class ScanTarget(BaseModel):
    # One network to scan, as submitted in ScanRequest.targets.
    # vlan_id is copied unchanged onto every host discovered in this target.
    vlan_id: int
    # Network in CIDR notation; scan() parses it with
    # ipaddress.ip_network(..., strict=False), so host bits may be set.
    cidr: str
class ScanRequest(BaseModel):
    # Body of POST /scan.
    # dns_server: IP of the resolver used for PTR lookups; an empty string
    #             disables reverse lookups (see _scan_one).
    # targets:    networks to sweep.
    # tcp_check:  probe hosts over TCP instead of ICMP.
    # soft_scan:  lower the ICMP concurrency (ignored when tcp_check is set).
    dns_server: str = ""
    targets: list[ScanTarget]
    tcp_check: bool = False
    soft_scan: bool = False

    @field_validator("dns_server")
    @classmethod
    def _dns_server(cls, v: str) -> str:
        # An empty value means "no DNS server" and needs no validation.
        if not v:
            return v
        # Anything else must parse as an IPv4 or IPv6 address literal.
        try:
            ipaddress.ip_address(v)
        except ValueError:
            raise ValueError(f"dns_server must be a valid IP address, got: {v!r}")
        return v
# GET /config: report the default DNS server taken from the DNS_SERVER
# environment variable when the module was imported ("" when not configured).
@router.get("/config")
def get_config():
    config = {"dns_server": _ENV_DNS}
    return config
class DiscoveredHost(BaseModel):
    # One host that answered a probe during a scan.
    ip: str
    # PTR name without the trailing dot; None when no DNS server was given
    # or the reverse lookup failed.
    hostname: Optional[str] = None
    # vlan_id and cidr are copied from the ScanTarget the host was found in.
    vlan_id: int
    cidr: str
class ScanResponse(BaseModel):
    # Result of POST /scan.
    # Responding hosts, sorted by numeric IP address.
    hosts: list[DiscoveredHost]
    # Number of addresses probed (responding or not).
    total_scanned: int
    # Time spent probing, in seconds, rounded to one decimal.
    duration_s: float
def _ping(ip: str) -> bool:
    """Return True when a single ICMP echo to *ip* is answered by *ip* itself.

    Runs the system ``ping`` with one packet (``-c 1``) and a reply wait of 1
    (``-W 1``), under a 3-second subprocess timeout. Any failure (non-zero
    exit status, timeout, missing binary, ...) is reported as ``False``; this
    function never raises.
    """
    command = ["ping", "-c", "1", "-W", "1", ip]
    try:
        completed = subprocess.run(command, capture_output=True, timeout=3)
        if completed.returncode != 0:
            return False
        output = completed.stdout.decode(errors="ignore")
        # Guard against proxy-ARP / gateway false positives: the reply must
        # name the target IP as its source, not an intermediate node. The
        # third marker covers the "from hostname (ip):" form printed when
        # the target resolves through DNS.
        reply_markers = (f"from {ip}:", f"from {ip} ", f"({ip}):")
        return any(marker in output for marker in reply_markers)
    except Exception:
        return False
def _ptr_lookup(ip: str, nameserver: str) -> Optional[str]:
    """Resolve the PTR record of *ip* through *nameserver*.

    Returns the first PTR target without its trailing dot, or ``None`` when
    the lookup fails for any reason (invalid address, timeout, no record...).
    The resolver ignores the system configuration and queries only
    *nameserver*, with ``timeout = 1`` and ``lifetime = 2``.
    """
    try:
        reverse_name = dns.reversename.from_address(ip)
        resolver = dns.resolver.Resolver(configure=False)
        resolver.nameservers = [nameserver]
        resolver.timeout = 1
        resolver.lifetime = 2
        answer = resolver.resolve(reverse_name, "PTR")
        first_record = answer[0]
        return str(first_record).rstrip(".")
    except Exception:
        return None
_TCP_PROBE_PORTS = (22, 80, 443, 8080, 8443)
_TCP_PROBE_TIMEOUT = 0.5  # seconds per port


def _tcp_check(
    ip: str,
    ports: tuple[int, ...] = _TCP_PROBE_PORTS,
    timeout: float = _TCP_PROBE_TIMEOUT,
) -> bool:
    """Return True when *ip* answers a TCP connection attempt on any probe port.

    Used by ``_scan_one`` in place of ICMP when ``tcp_check`` is enabled: some
    gateways (e.g. UniFi) respond to ICMP for every IP in the subnet via
    proxy-ARP, spoofing the source IP so the source-IP guard in ``_ping()``
    cannot help. A real host will reply to TCP (RST = port closed, or
    accept = port open); a ghost IP gets its SYN dropped by the gateway,
    which shows up here as a timeout.

    Args:
        ip: IPv4 or IPv6 address literal to probe.
        ports: TCP ports tried in order; the first answer wins.
        timeout: Per-port connection timeout in seconds.

    Returns:
        True as soon as one port accepts or actively refuses the connection,
        False when every port times out or errors.
    """
    # An AF_INET socket cannot reach an IPv6 address (connect_ex raises and
    # the host would always look dead), so pick the family from the literal.
    family = socket.AF_INET6 if ":" in ip else socket.AF_INET
    for port in ports:
        try:
            with socket.socket(family, socket.SOCK_STREAM) as sock:
                sock.settimeout(timeout)
                err = sock.connect_ex((ip, port))
        except OSError:
            # Unsupported family, unparsable address, etc.: try the next port.
            continue
        if err in (0, errno.ECONNREFUSED):
            return True
    return False
def _scan_one(ip: str, dns_server: str, vlan_id: int, cidr: str, tcp_check: bool = False) -> Optional[DiscoveredHost]:
    """Probe one address and describe it when it responds.

    Returns ``None`` when the host does not answer the selected probe.
    Otherwise returns a ``DiscoveredHost`` whose hostname comes from a PTR
    lookup against *dns_server* (skipped when *dns_server* is empty).
    """
    # TCP-only mode bypasses ICMP entirely: proxy-ARP gateways never spoof
    # TCP replies, so ghost IPs are filtered without ICMP. It also catches
    # hosts whose ICMP is rate-limited under load.
    probe = _tcp_check if tcp_check else _ping
    if not probe(ip):
        return None

    hostname = None
    if dns_server:
        hostname = _ptr_lookup(ip, dns_server)
    return DiscoveredHost(ip=ip, hostname=hostname, vlan_id=vlan_id, cidr=cidr)
class PingRequest(BaseModel):
    # Body of POST /ping: the addresses to probe.
    ips: list[str]

    @field_validator("ips")
    @classmethod
    def _ips(cls, v: list[str]) -> list[str]:
        # Reject oversized lists first, before any per-item work is done.
        count = len(v)
        if count > MAX_PING_IPS:
            raise ValueError(f"Too many IPs: {count} (max {MAX_PING_IPS})")
        # Every entry must be an IPv4 or IPv6 address literal.
        for candidate in v:
            try:
                ipaddress.ip_address(candidate)
            except ValueError:
                raise ValueError(f"Invalid IP address: {candidate!r}")
        return v
class PingResult(BaseModel):
    # Outcome of one ICMP probe issued by POST /ping.
    ip: str
    # True when _ping() got a reply from this address.
    alive: bool
# POST /ping: ICMP-probe every requested address using up to 50 worker threads.
# Results are returned in completion order, not in request order; an empty
# request yields an empty list.
@router.post("/ping", response_model=list[PingResult])
def ping_many(req: PingRequest):
    if not req.ips:
        return []

    outcomes = []
    with ThreadPoolExecutor(max_workers=50) as pool:
        # Map each future back to the address it is probing.
        pending = {}
        for address in req.ips:
            pending[pool.submit(_ping, address)] = address
        for done in as_completed(pending):
            outcomes.append(PingResult(ip=pending[done], alive=done.result()))
    return outcomes
def _target_too_large(cidr: str, host_count: int) -> HTTPException:
    """Build the 400 error raised when one target exceeds MAX_HOSTS_PER_TARGET."""
    return HTTPException(
        400,
        f"Réseau {cidr} trop large ({host_count} hôtes). "
        f"Maximum par target : {MAX_HOSTS_PER_TARGET} hôtes (/22 ou plus petit).",
    )


# POST /scan: probe every host address of every requested target and return
# the hosts that answered, sorted by IP.
#
# Responds with HTTP 400 when a CIDR is invalid, when one target holds more
# than MAX_HOSTS_PER_TARGET hosts, when the request holds no host at all, or
# when the targets add up to more than MAX_HOSTS_TOTAL hosts.
@router.post("/scan", response_model=ScanResponse)
def scan(req: ScanRequest):
    tasks: list[tuple[str, str, int, str]] = []
    total_hosts = 0
    for t in req.targets:
        try:
            net = ipaddress.ip_network(t.cidr, strict=False)
        except ValueError as exc:
            raise HTTPException(400, f"CIDR invalide : {t.cidr}") from exc

        # Size check BEFORE enumerating: listing the hosts of e.g. a /8 or an
        # IPv6 /64 would exhaust memory long before the limit is compared.
        # hosts() drops at most two addresses, so anything above the limit
        # plus two is necessarily too large; for networks that wide the host
        # count is all addresses minus network+broadcast (IPv4) or minus the
        # subnet-router anycast address (IPv6).
        if net.num_addresses > MAX_HOSTS_PER_TARGET + 2:
            excluded = 2 if net.version == 4 else 1
            raise _target_too_large(t.cidr, net.num_addresses - excluded)

        hosts = list(net.hosts())
        if len(hosts) > MAX_HOSTS_PER_TARGET:
            raise _target_too_large(t.cidr, len(hosts))

        # Keep counting so the error below reports the exact total, but stop
        # accumulating tasks once the global cap is exceeded.
        total_hosts += len(hosts)
        if total_hosts <= MAX_HOSTS_TOTAL:
            tasks.extend((str(ip), req.dns_server, t.vlan_id, t.cidr) for ip in hosts)

    if not total_hosts:
        raise HTTPException(400, "Aucune cible à scanner.")
    if total_hosts > MAX_HOSTS_TOTAL:
        raise HTTPException(
            400,
            f"Trop d'hôtes au total ({total_hosts}). Maximum global : {MAX_HOSTS_TOTAL}.",
        )

    # Monotonic clock: the reported duration is immune to wall-clock changes.
    t0 = time.monotonic()
    results: list[DiscoveredHost] = []
    # Soft scan reduces ICMP concurrency to avoid rate-limiting on switches/APs.
    # Has no effect in tcp_check mode (TCP probes are not rate-limited the same way).
    workers = 10 if (req.soft_scan and not req.tcp_check) else 100
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(_scan_one, *args, tcp_check=req.tcp_check) for args in tasks]
        for f in as_completed(futures):
            host = f.result()
            if host:
                results.append(host)
    # Numeric sort, so that e.g. 10.0.0.9 comes before 10.0.0.10.
    results.sort(key=lambda h: ipaddress.ip_address(h.ip))
    return ScanResponse(
        hosts=results,
        total_scanned=len(tasks),
        duration_s=round(time.monotonic() - t0, 1),
    )