211 lines
6.9 KiB
Python
211 lines
6.9 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any, Iterable
|
||
|
|
|
||
|
|
from faster_whisper import WhisperModel
|
||
|
|
|
||
|
|
|
||
|
|
def load_model(
    device: str,
    compute_type: str,
    device_index: int,
) -> tuple[WhisperModel, str, str, int]:
    """Load the local CTranslate2 model, falling back to CPU if needed.

    Attempts to initialise the model on ``device``; when a GPU-class device
    ("cuda" or "mps") fails to initialise, retries on CPU with int8 compute.
    A CPU failure is re-raised unchanged.

    Returns:
        (model, device_used, compute_type_used, device_index_used)
    """
    # device_index is only meaningful for CUDA; everything else pins index 0.
    effective_index = device_index if device == "cuda" else 0
    try:
        model = WhisperModel(
            "model",
            device=device,
            compute_type=compute_type,
            device_index=effective_index,
        )
    except Exception as exc:
        # Only GPU-class devices get the CPU fallback path.
        if device not in {"mps", "cuda"}:
            raise
        print(f"{device.upper()} unavailable, falling back to CPU (reason: {exc})")
        cpu_model = WhisperModel("model", device="cpu", compute_type="int8")
        return cpu_model, "cpu", "int8", 0
    return model, device, compute_type, effective_index
|
||
|
|
|
||
|
|
|
||
|
|
def format_timestamp(seconds: float) -> str:
    """Render a time offset in seconds as an LRC-style ``[mm:ss.cc]`` tag."""
    centis = int(round(seconds * 100))
    minutes = centis // 6000
    secs = (centis // 100) % 60
    hundredths = centis % 100
    return f"[{minutes:02d}:{secs:02d}.{hundredths:02d}]"
|
||
|
|
|
||
|
|
|
||
|
|
def write_lrc(segments: Iterable, output_path: Path) -> None:
    """Write timestamped segments to *output_path* in LRC format.

    Segments whose text is blank after stripping are dropped; if nothing
    remains, no file is written and a notice is printed instead.
    """
    entries = [
        f"{format_timestamp(segment.start)}{stripped}"
        for segment in segments
        if (stripped := segment.text.strip())
    ]
    if not entries:
        print(f"Skipping empty transcript for {output_path}")
        return
    output_path.write_text("\n".join(entries) + "\n", encoding="utf-8")
|
||
|
|
|
||
|
|
|
||
|
|
def transcribe_file(
    model: WhisperModel,
    audio_path: Path,
    language: str | None,
    beam_size: int,
    vad: bool,
    vad_parameters: dict | None,
    extra_generation_args: dict[str, Any] | None = None,
) -> tuple[list, str, float]:
    """Transcribe one audio file, echoing each segment as it is decoded.

    Returns:
        (segments, detected_language, language_probability)
    """
    extra_kwargs: dict[str, Any] = dict(extra_generation_args or {})
    segment_stream, info = model.transcribe(
        str(audio_path),
        beam_size=beam_size,
        vad_filter=vad,
        vad_parameters=vad_parameters if vad else None,
        language=language,
        **extra_kwargs,
    )
    print(f"[{audio_path.name}] Detected language: {info.language} (prob={info.language_probability:.2f})")

    # Stream segments as they arrive so the console updates in real time.
    collected = []
    for segment in segment_stream:
        collected.append(segment)
        print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}", flush=True)

    return collected, info.language, info.language_probability
|
||
|
|
|
||
|
|
|
||
|
|
def collect_audio_files(path: Path) -> list[Path]:
    """Resolve *path* to a sorted list of .mp3 files.

    A single file is returned as a one-element list; a directory is searched
    recursively for ``*.mp3``.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"Audio path not found: {path}")
    if path.is_file():
        return [path]
    # Recursive glob so nested album folders are picked up too.
    return sorted(path.rglob("*.mp3"))
|
||
|
|
|
||
|
|
|
||
|
|
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface for the transcriber."""
    p = argparse.ArgumentParser(description="Run Whisper inference with local CTranslate2 model.")
    p.add_argument(
        "audio",
        nargs="?",
        default="mp3",
        help="Path to an audio file or directory (default: ./mp3).",
    )
    p.add_argument("--language", default="ja", help="Language code (e.g., zh, en). Default: ja.")
    p.add_argument("--beam-size", type=int, default=5, help="Beam size for decoding.")
    p.add_argument("--no-vad", action="store_true", help="Disable VAD (voice activity detection) filtering.")
    p.add_argument(
        "--vad-threshold",
        type=float,
        default=0.35,
        help="Speech probability threshold for VAD. Lower is less aggressive (default: 0.35).",
    )
    p.add_argument(
        "--vad-neg-threshold",
        type=float,
        default=None,
        help="Optional silence probability threshold for VAD (useful to smooth speech end).",
    )
    p.add_argument(
        "--vad-min-silence-ms",
        type=int,
        default=400,
        help="Minimum silence (ms) to cut a speech chunk when VAD is enabled (default: 400).",
    )
    p.add_argument(
        "--vad-pad-ms",
        type=int,
        default=500,
        help="Padding (ms) added before/after each detected speech chunk (default: 500).",
    )
    p.add_argument(
        "--device",
        default="cuda",  # CUDA by default
        choices=["cpu", "mps", "cuda"],
        help="Inference device. Use 'cuda' on NVIDIA GPUs, 'mps' on Apple Silicon.",
    )
    p.add_argument(
        "--device-index",
        type=int,
        default=0,
        help="CUDA device index (e.g., 0 for the first GPU). Ignored for CPU/MPS.",
    )
    p.add_argument(
        "--compute-type",
        default=None,
        help=(
            "Override compute type (e.g., int8, int8_float16, float16). "
            "Default: CUDA/MPS=float16, CPU=int8."
        ),
    )
    return p.parse_args()
|
||
|
|
|
||
|
|
|
||
|
|
def main() -> None:
    """Entry point: parse arguments, load the model, transcribe every file."""
    args = parse_args()
    target = Path(args.audio)

    audio_files = collect_audio_files(target)
    if not audio_files:
        print(f"No .mp3 files found under {target}")
        return

    # GPU-class devices default to float16; CPU sticks with int8 quantisation.
    default_compute = "float16" if args.device in {"cuda", "mps"} else "int8"
    compute_type = args.compute_type or default_compute
    model, device_used, compute_used, device_index_used = load_model(
        device=args.device,
        compute_type=compute_type,
        device_index=args.device_index,
    )
    if device_used == "cuda":
        print(f"Using device={device_used}:{device_index_used}, compute_type={compute_used}")
    else:
        print(f"Using device={device_used}, compute_type={compute_used}")

    if args.no_vad:
        vad_parameters = None
    else:
        vad_parameters = {
            "threshold": args.vad_threshold,
            "neg_threshold": args.vad_neg_threshold,
            "min_silence_duration_ms": args.vad_min_silence_ms,
            "speech_pad_ms": args.vad_pad_ms,
        }
        print(
            "VAD enabled: "
            f"threshold={vad_parameters['threshold']}, "
            f"neg_threshold={vad_parameters['neg_threshold']}, "
            f"min_silence_ms={vad_parameters['min_silence_duration_ms']}, "
            f"pad_ms={vad_parameters['speech_pad_ms']}"
        )

    # Extra decoding knobs forwarded verbatim to model.transcribe().
    generation_args: dict[str, Any] = {
        # "max_initial_timestamp": 10,
        "repetition_penalty": 1.1,
    }

    total = len(audio_files)
    for idx, audio_path in enumerate(audio_files, start=1):
        print(f"\n[{idx}/{total}] Processing {audio_path}")
        try:
            segments, _, _ = transcribe_file(
                model=model,
                audio_path=audio_path,
                language=args.language,
                beam_size=args.beam_size,
                vad=not args.no_vad,
                vad_parameters=vad_parameters,
                extra_generation_args=generation_args,
            )
            write_lrc(segments, audio_path.with_suffix(".lrc"))
        except Exception as exc:
            # Keep going: one bad file must not abort the whole batch.
            print(f"Failed to process {audio_path}: {exc}")
|
||
|
|
|
||
|
|
|
||
|
|
# Run only when executed as a script, so the module stays importable.
if __name__ == "__main__":
    main()
|