import ipaddress
import subprocess
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

import dns.resolver
import dns.reversename
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, field_validator

router = APIRouter()

MAX_HOSTS_PER_TARGET = 1024  # per-target cap: rejects /21 and wider networks
MAX_HOSTS_TOTAL = 4096  # global cap summed over every target of one request


def _is_ip_literal(text: str) -> bool:
    """Return True when *text* parses as an IPv4 or IPv6 address literal."""
    try:
        ipaddress.ip_address(text)
    except ValueError:
        return False
    return True


# One network to sweep, labelled with the VLAN it belongs to.
# NOTE: the CIDR string is not validated at model level.
class ScanTarget(BaseModel):
    vlan_id: int
    cidr: str


# Body of a scan request: the networks to sweep plus the DNS server that
# reverse (PTR) lookups are sent to.
class ScanRequest(BaseModel):
    dns_server: str = "8.8.8.8"
    targets: list[ScanTarget]

    @field_validator("dns_server")
    @classmethod
    def _dns_server(cls, v: str) -> str:
        """Accept only a literal IP address; hostnames are rejected."""
        if not _is_ip_literal(v):
            raise ValueError(f"dns_server must be a valid IP address, got: {v!r}")
        return v


# A host that answered the ping. hostname holds its PTR name, or None when
# the reverse lookup failed.
class DiscoveredHost(BaseModel):
    ip: str
    hostname: Optional[str] = None
    vlan_id: int
    cidr: str


# Outcome of a scan: responding hosts, number of addresses probed, and the
# elapsed time in seconds.
class ScanResponse(BaseModel):
    hosts: list[DiscoveredHost]
    total_scanned: int
    duration_s: float


def _ping(ip: str) -> bool:
    """Send one echo request to *ip* and report whether `ping` exited with 0.

    ``-c 1`` limits ping to a single probe and ``-W 1`` bounds the wait for a
    reply (the unit of ``-W`` depends on the ping implementation). The whole
    subprocess is killed after 3 seconds. This is best-effort: any failure to
    run ping (missing binary, timeout, ...) counts as "not alive".
    """
    command = ["ping", "-c", "1", "-W", "1", ip]
    try:
        completed = subprocess.run(command, capture_output=True, timeout=3)
    except Exception:
        return False
    return completed.returncode == 0


def _ptr_lookup(ip: str, nameserver: str) -> Optional[str]:
    """Reverse-resolve *ip* by querying only *nameserver*.

    Returns the first PTR target without its trailing dot, or None on any
    failure (timeout, NXDOMAIN, no answer, ...).
    """
    try:
        # configure=False: skip the system resolver configuration entirely.
        resolver = dns.resolver.Resolver(configure=False)
        resolver.nameservers = [nameserver]
        resolver.timeout = 1
        resolver.lifetime = 2
        answer = resolver.resolve(dns.reversename.from_address(ip), "PTR")
        return str(answer[0]).rstrip(".")
    except Exception:
        return None


def _scan_one(ip: str, dns_server: str, vlan_id: int, cidr: str) -> Optional[DiscoveredHost]:
    """Ping *ip*; if it answers, return it (with its PTR name, if any), else None."""
    if _ping(ip):
        return DiscoveredHost(
            ip=ip,
            hostname=_ptr_lookup(ip, dns_server),
            vlan_id=vlan_id,
            cidr=cidr,
        )
    return None


# Body of a ping request: the literal IP addresses to probe.
class PingRequest(BaseModel):
    ips: list[str]

    @field_validator("ips")
    @classmethod
    def _ips(cls, v: list[str]) -> list[str]:
        """Reject the list on its first entry that is not a literal IP address."""
        offender = next((ip for ip in v if not _is_ip_literal(ip)), None)
        if offender is not None:
            raise ValueError(f"Invalid IP address: {offender!r}")
        return v


# Per-address outcome of a ping request.
class PingResult(BaseModel):
    ip: str
    alive: bool
def _host_count(net: "ipaddress.IPv4Network | ipaddress.IPv6Network") -> int:
    """Return how many addresses ``net.hosts()`` would yield, without iterating it.

    Oversized targets (e.g. an IPv6 /64, i.e. 2**64 addresses) must be
    rejected in O(1). Materialising their host list first would exhaust
    memory and CPU.
    """
    total = net.num_addresses
    if net.version == 4 and net.prefixlen < 31:
        return total - 2  # network and broadcast addresses are not hosts
    if net.version == 6 and net.prefixlen < 127:
        return total - 1  # the Subnet-Router anycast address is not a host
    return total  # /31, /32, /127, /128: every address counts as a host


def _ip_sort_key(host: DiscoveredHost) -> tuple[int, int]:
    """Order IPv4 before IPv6, then numerically within each family.

    Comparing an IPv4Address with an IPv6Address raises TypeError, so
    sorting by the address object alone breaks on mixed-family scans.
    """
    addr = ipaddress.ip_address(host.ip)
    return addr.version, int(addr)


@router.post("/ping", response_model=list[PingResult])
def ping_many(req: PingRequest):
    """Ping each requested address once, concurrently.

    Results are returned in the same order as ``ips`` in the request.
    At most MAX_HOSTS_TOTAL addresses are accepted per call; larger
    requests get HTTP 400.
    """
    if not req.ips:
        return []
    # Each address costs one `ping` subprocess, so cap the fan-out like /scan does.
    if len(req.ips) > MAX_HOSTS_TOTAL:
        raise HTTPException(
            400,
            f"Trop d'adresses ({len(req.ips)}). Maximum : {MAX_HOSTS_TOTAL}.",
        )
    with ThreadPoolExecutor(max_workers=50) as pool:
        # map() yields results in submission order, unlike as_completed().
        alive = list(pool.map(_ping, req.ips))
    return [PingResult(ip=ip, alive=up) for ip, up in zip(req.ips, alive)]


@router.post("/scan", response_model=ScanResponse)
def scan(req: ScanRequest):
    """Ping-sweep every usable host of each target network and PTR-resolve responders.

    Returns the responding hosts sorted by address (IPv4 before IPv6), the
    number of addresses probed and the elapsed sweep time in seconds.

    Responds with HTTP 400 in any of these cases:
    - a CIDR is invalid;
    - a single target exceeds MAX_HOSTS_PER_TARGET hosts;
    - all targets together exceed MAX_HOSTS_TOTAL hosts;
    - there is nothing to scan.
    """
    # Phase 1: parse and size every target before expanding any of them, so
    # an oversized request is refused without allocating its host list.
    networks: list[tuple[ScanTarget, ipaddress.IPv4Network | ipaddress.IPv6Network]] = []
    total = 0
    for t in req.targets:
        try:
            net = ipaddress.ip_network(t.cidr, strict=False)
        except ValueError:
            raise HTTPException(400, f"CIDR invalide : {t.cidr}") from None
        n_hosts = _host_count(net)
        if n_hosts > MAX_HOSTS_PER_TARGET:
            raise HTTPException(
                400,
                f"Réseau {t.cidr} trop large ({n_hosts} hôtes). "
                f"Maximum par target : {MAX_HOSTS_PER_TARGET} hôtes (/22 ou plus petit).",
            )
        networks.append((t, net))
        total += n_hosts

    if total > MAX_HOSTS_TOTAL:
        raise HTTPException(
            400,
            f"Trop d'hôtes au total ({total}). Maximum global : {MAX_HOSTS_TOTAL}.",
        )

    # Phase 2: expand the now-bounded networks into one probe task per host.
    tasks: list[tuple[str, str, int, str]] = [
        (str(ip), req.dns_server, t.vlan_id, t.cidr)
        for t, net in networks
        for ip in net.hosts()
    ]
    if not tasks:
        raise HTTPException(400, "Aucune cible à scanner.")

    t0 = time.monotonic()  # monotonic clock: unaffected by wall-clock adjustments
    results: list[DiscoveredHost] = []
    with ThreadPoolExecutor(max_workers=100) as pool:
        futures = [pool.submit(_scan_one, *args) for args in tasks]
        for f in as_completed(futures):
            host = f.result()
            if host is not None:
                results.append(host)
    results.sort(key=_ip_sort_key)
    return ScanResponse(
        hosts=results,
        total_scanned=len(tasks),
        duration_s=round(time.monotonic() - t0, 1),
    )