"""Transcribe audio files with a local faster-whisper (CTranslate2) model and write LRC files."""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Any, Iterable

from faster_whisper import WhisperModel


def load_model(
    device: str,
    compute_type: str,
    device_index: int,
) -> tuple[WhisperModel, str, str, int]:
    """Try to load the model on the requested device; if GPU init fails, fall back to CPU."""
    try:
        model = WhisperModel(
            "model",
            device=device,
            compute_type=compute_type,
            device_index=device_index if device == "cuda" else 0,
        )
        return model, device, compute_type, (device_index if device == "cuda" else 0)
    except Exception as exc:
        if device in {"mps", "cuda"}:
            fallback_device = "cpu"
            fallback_compute = "int8"
            print(f"{device.upper()} unavailable, falling back to CPU (reason: {exc})")
            model = WhisperModel("model", device=fallback_device, compute_type=fallback_compute)
            return model, fallback_device, fallback_compute, 0
        raise


def format_timestamp(seconds: float) -> str:
    """Render seconds as the [MM:SS.xx] timestamp format used by LRC files."""
    total_centiseconds = int(round(seconds * 100))
    minutes, remainder = divmod(total_centiseconds, 6000)
    secs, centiseconds = divmod(remainder, 100)
    return f"[{minutes:02d}:{secs:02d}.{centiseconds:02d}]"


def write_lrc(segments: Iterable, output_path: Path) -> None:
    """Write one timestamped line per non-empty segment; skip the file if nothing remains."""
    lines = []
    for seg in segments:
        text = seg.text.strip()
        if not text:
            continue
        lines.append(f"{format_timestamp(seg.start)}{text}")
    if not lines:
        print(f"Skipping empty transcript for {output_path}")
        return
    output_path.write_text("\n".join(lines) + "\n", encoding="utf-8")


def transcribe_file(
    model: WhisperModel,
    audio_path: Path,
    language: str | None,
    beam_size: int,
    vad: bool,
    vad_parameters: dict | None,
    extra_generation_args: dict[str, Any] | None = None,
) -> tuple[list, str, float]:
    segments_iter, info = model.transcribe(
        str(audio_path),
        beam_size=beam_size,
        vad_filter=vad,
        vad_parameters=vad_parameters if vad else None,
        language=language,
        **(extra_generation_args or {}),
    )
    print(f"[{audio_path.name}] Detected language: {info.language} (prob={info.language_probability:.2f})")
    # Stream segments as they arrive so the console updates in real time.
    segments = []
    for seg in segments_iter:
        segments.append(seg)
        print(f"[{seg.start:.2f} -> {seg.end:.2f}] {seg.text}", flush=True)
    return segments, info.language, info.language_probability


def collect_audio_files(path: Path) -> list[Path]:
    """Return a single file as-is, or recursively collect .mp3 files under a directory."""
    if path.is_file():
        return [path]
    if not path.exists():
        raise FileNotFoundError(f"Audio path not found: {path}")
    return sorted(path.rglob("*.mp3"))


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run Whisper inference with a local CTranslate2 model.")
    parser.add_argument(
        "audio",
        nargs="?",
        default="mp3",
        help="Path to an audio file or directory (default: ./mp3).",
    )
    parser.add_argument("--language", help="Language code (e.g., zh, en). Default: ja.", default="ja")
    parser.add_argument("--beam-size", type=int, default=5, help="Beam size for decoding.")
    parser.add_argument("--no-vad", action="store_true", help="Disable VAD (voice activity detection) filtering.")
    parser.add_argument(
        "--vad-threshold",
        type=float,
        default=0.35,
        help="Speech probability threshold for VAD. Lower is less aggressive (default: 0.35).",
    )
    parser.add_argument(
        "--vad-neg-threshold",
        type=float,
        default=None,
        help="Optional silence probability threshold for VAD (useful to smooth speech ends).",
    )
    parser.add_argument(
        "--vad-min-silence-ms",
        type=int,
        default=400,
        help="Minimum silence (ms) to cut a speech chunk when VAD is enabled (default: 400).",
    )
    parser.add_argument(
        "--vad-pad-ms",
        type=int,
        default=500,
        help="Padding (ms) added before/after each detected speech chunk (default: 500).",
    )
    parser.add_argument(
        "--device",
        default="cuda",  # Default to CUDA.
        choices=["cpu", "mps", "cuda"],
        help="Inference device. Use 'cuda' on NVIDIA GPUs, 'mps' on Apple Silicon.",
    )
    parser.add_argument(
        "--device-index",
        type=int,
        default=0,
        help="CUDA device index (e.g., 0 for the first GPU). Ignored for CPU/MPS.",
    )
    parser.add_argument(
        "--compute-type",
        default=None,
        help=(
            "Override compute type (e.g., int8, int8_float16, float16). "
            "Default: CUDA/MPS=float16, CPU=int8."
        ),
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    target = Path(args.audio)
    audio_files = collect_audio_files(target)
    if not audio_files:
        print(f"No .mp3 files found under {target}")
        return

    compute_type = args.compute_type or ("float16" if args.device in {"cuda", "mps"} else "int8")
    model, device_used, compute_used, device_index_used = load_model(
        device=args.device,
        compute_type=compute_type,
        device_index=args.device_index,
    )
    if device_used == "cuda":
        print(f"Using device={device_used}:{device_index_used}, compute_type={compute_used}")
    else:
        print(f"Using device={device_used}, compute_type={compute_used}")

    vad_parameters = {
        "threshold": args.vad_threshold,
        "min_silence_duration_ms": args.vad_min_silence_ms,
        "speech_pad_ms": args.vad_pad_ms,
    }
    if args.vad_neg_threshold is not None:
        # Only forward neg_threshold when set; older faster-whisper releases
        # reject unknown VadOptions keys.
        vad_parameters["neg_threshold"] = args.vad_neg_threshold
    if args.no_vad:
        vad_parameters = None
    else:
        print(
            "VAD enabled: "
            f"threshold={vad_parameters['threshold']}, "
            f"neg_threshold={vad_parameters.get('neg_threshold')}, "
            f"min_silence_ms={vad_parameters['min_silence_duration_ms']}, "
            f"pad_ms={vad_parameters['speech_pad_ms']}"
        )

    generation_args: dict[str, Any] = {
        # "max_initial_timestamp": 10,
        "repetition_penalty": 1.1,
    }

    for idx, audio_path in enumerate(audio_files, start=1):
        print(f"\n[{idx}/{len(audio_files)}] Processing {audio_path}")
        try:
            segments, _, _ = transcribe_file(
                model=model,
                audio_path=audio_path,
                language=args.language,
                beam_size=args.beam_size,
                vad=not args.no_vad,
                vad_parameters=vad_parameters,
                extra_generation_args=generation_args,
            )
            write_lrc(segments, audio_path.with_suffix(".lrc"))
        except Exception as exc:
            print(f"Failed to process {audio_path}: {exc}")


if __name__ == "__main__":
    main()
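
# Example invocations (a sketch: the script name transcribe.py is an assumption,
# and the CTranslate2 model directory "model" referenced by WhisperModel("model", ...)
# above is assumed to sit in the working directory):
#
#   python transcribe.py                                  # every .mp3 under ./mp3, Japanese, CUDA
#   python transcribe.py song.mp3 --language zh --device cpu
#   python transcribe.py ./mp3 --no-vad --beam-size 10
#
# Each input produces a sibling .lrc file (song.mp3 -> song.lrc) whose lines
# follow the format emitted by format_timestamp(), e.g.:
#
#   [00:12.34]transcribed text for this segment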