# Standard-library imports used by ModelManager; the domain types referenced
# below (AppConfig, CachedModel, RuntimeOptions, ...) are assumed to be
# defined or imported elsewhere in this module.
import time
from collections import OrderedDict
from threading import BoundedSemaphore, Event, Lock


class ModelManager:
    """LRU cache of loaded transcriber models with single-flight loading."""

    def __init__(self, config: AppConfig, *, resolver: PresetResolver = get_preset) -> None:
        self._config = config
        self._resolver = resolver
        # LRU cache of loaded models keyed by RuntimeOptions.cache_key.
        self._cache: OrderedDict[str, CachedModel] = OrderedDict()
        # In-flight loads; waiters block on the Event until the owner finishes.
        self._loading: dict[str, Event] = {}
        self._lock = Lock()
        # Caps the number of transcriptions running at the same time.
        self._semaphore = BoundedSemaphore(config.max_concurrent_transcriptions)
        self._hits = 0
        self._misses = 0
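
    # load() protocol: the fast path looks up the LRU cache under the lock; on
    # a miss, exactly one caller claims the key by registering an Event in
    # self._loading and builds the model outside the lock, while concurrent
    # callers for the same key wait on that Event and then re-check the cache.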
    def load(self, runtime: RuntimeOptions) -> ModelLoadResult:
        """Return a transcriber for ``runtime``, loading and caching it on a miss."""
        started = time.perf_counter()
        while True:
            now = time.time()
            with self._lock:
                self._unload_idle_locked(now)
                cached = self._cache.pop(runtime.cache_key, None)
                if cached is not None:
                    # Cache hit: re-insert the entry to mark it most recently used.
                    cached.last_used_at = now
                    cached.hits += 1
                    self._cache[runtime.cache_key] = cached
                    self._hits += 1
                    return ModelLoadResult(
                        transcriber=cached.transcriber,
                        cache_hit=True,
                        load_seconds=time.perf_counter() - started,
                    )
                loading = self._loading.get(runtime.cache_key)
                if loading is None:
                    # Nobody is loading this key yet: claim it and leave the loop.
                    loading = Event()
                    self._loading[runtime.cache_key] = loading
                    break
            # Another thread is loading this key; wait, then re-check the cache.
            loading.wait()
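        # Slow path: this thread now owns the load for runtime.cache_key and
        # constructs the transcriber outside the lock.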
        preset = self._resolver(runtime.preset)
        overrides: dict[str, object] = {}
        if preset.backend == "faster-whisper":
            # Only the faster-whisper backend accepts these runtime overrides.
            overrides = {
                "cpu_threads": runtime.cpu_threads,
                "num_workers": runtime.num_workers,
            }
        try:
            transcriber = create_transcriber_from_preset(preset, **overrides)
        except Exception:
            # The load failed: drop the in-flight marker and wake waiters so one
            # of them can retry.
            with self._lock:
                loading = self._loading.pop(runtime.cache_key, None)
            if loading is not None:
                loading.set()
            raise
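        # Load succeeded: publish the model, trim the cache to the configured
        # maximum, and wake any threads waiting on this key.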
        with self._lock:
            loaded_at = time.time()
            self._misses += 1
            self._cache[runtime.cache_key] = CachedModel(
                runtime=runtime,
                transcriber=transcriber,
                loaded_at=loaded_at,
                last_used_at=loaded_at,
            )
            while len(self._cache) > self._config.max_loaded_models:
                self._cache.popitem(last=False)
                reclaim_released_memory()
            loading = self._loading.pop(runtime.cache_key, None)
        if loading is not None:
            loading.set()
        return ModelLoadResult(
            transcriber=transcriber,
            cache_hit=False,
            load_seconds=time.perf_counter() - started,
        )

    def transcribe(
        self,
        runtime: RuntimeOptions,
        audio_path: str,
        options: TranscriptionOptions,
    ) -> tuple[TranscriptionResult, dict[str, object]]:
        """Transcribe ``audio_path`` with the model selected by ``runtime``.

        Returns the backend result together with cache and timing metadata;
        the semaphore caps how many transcriptions run concurrently.
        """
        load_result = self.load(runtime)
        started = time.perf_counter()
        with self._semaphore:
            result = load_result.transcriber.transcribe(audio_path, options)
        transcribe_seconds = time.perf_counter() - started
        return result, {
            "cache_hit": load_result.cache_hit,
            "load_seconds": round(load_result.load_seconds, 3),
            "transcribe_seconds": round(transcribe_seconds, 3),
            "total_seconds": round(load_result.load_seconds + transcribe_seconds, 3),
            "runtime": runtime.to_dict(),
        }

    def unload_all(self) -> dict[str, object]:
        """Drop every cached model and attempt to return freed memory to the OS."""
        with self._lock:
            entries = [cached.runtime.to_dict() for cached in self._cache.values()]
            self._cache.clear()
            return {
                "unloaded": len(entries),
                "entries": entries,
                "native_trim_attempted": reclaim_released_memory(),
            }

    def update_config(self, config: AppConfig) -> None:
        """Swap in a new config and evict models beyond the new cache limit."""
        with self._lock:
            self._config = config
            while len(self._cache) > self._config.max_loaded_models:
                self._cache.popitem(last=False)
                reclaim_released_memory()

    def unload_idle(self) -> int:
        with self._lock:
            return self._unload_idle_locked(time.time())

    def status(self) -> dict[str, object]:
        """Snapshot of cache contents, hit/miss counters, and process memory."""
        now = time.time()
        with self._lock:
            entries = [
                {
                    **cached.runtime.to_dict(),
                    "loaded_for_seconds": round(now - cached.loaded_at, 3),
                    "idle_for_seconds": round(now - cached.last_used_at, 3),
                    "hits": cached.hits,
                }
                for cached in self._cache.values()
            ]
            return {
                "loaded": len(entries),
                "max_loaded": self._config.max_loaded_models,
                "entries": entries,
                "hits": self._hits,
                "misses": self._misses,
                "idle_ttl_seconds": self._config.model_idle_ttl_seconds,
                "process_rss_bytes": current_rss_bytes(),
            }

    def _unload_idle_locked(self, now: float) -> int:
        """Evict entries idle past the TTL; the caller must hold ``self._lock``."""
        ttl = self._config.model_idle_ttl_seconds
        if ttl <= 0:
            return 0
        stale = [
            cache_key
            for cache_key, cached in self._cache.items()
            if now - cached.last_used_at >= ttl
        ]
        for cache_key in stale:
            self._cache.pop(cache_key, None)
        if stale:
            reclaim_released_memory()
        return len(stale)
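

# Usage sketch (illustrative; the exact AppConfig / RuntimeOptions /
# TranscriptionOptions values come from the rest of this module):
#
#     manager = ModelManager(config)
#     result, meta = manager.transcribe(runtime, "/path/to/audio.wav", options)
#     print(meta["cache_hit"], meta["total_seconds"])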