from __future__ import annotations import argparse import json from pathlib import Path from typing import Any, Iterable import librosa import numpy as np from faster_whisper import WhisperModel class AsmrVadModel: def __init__( self, model_path: Path, force_cpu: bool = False, num_threads: int = 1, ) -> None: try: import onnxruntime as ort except ImportError as exc: raise RuntimeError("onnxruntime is required for --vad-mode asmr") from exc try: from transformers import WhisperFeatureExtractor except ImportError as exc: raise RuntimeError("transformers is required for --vad-mode asmr") from exc metadata_path = model_path.with_name("model_metadata.json") metadata = { "whisper_model_name": "openai/whisper-base", "frame_duration_ms": 20, "total_duration_ms": 30000, } if metadata_path.exists(): metadata.update(json.loads(metadata_path.read_text(encoding="utf-8"))) self.sample_rate = 16000 self.frame_duration_ms = int(metadata.get("frame_duration_ms", 20)) self.chunk_duration_ms = int(metadata.get("total_duration_ms", 30000)) self.chunk_samples = int(self.chunk_duration_ms * self.sample_rate / 1000) opts = ort.SessionOptions() opts.inter_op_num_threads = num_threads opts.intra_op_num_threads = num_threads providers = ["CPUExecutionProvider"] if not force_cpu and "CUDAExecutionProvider" in ort.get_available_providers(): providers.insert(0, "CUDAExecutionProvider") whisper_model_name = metadata.get("whisper_model_name", "openai/whisper-base") local_whisper_path = Path(whisper_model_name) if local_whisper_path.exists(): feature_extractor_source = str(local_whisper_path) elif Path("model").exists(): feature_extractor_source = "model" else: feature_extractor_source = whisper_model_name self.feature_extractor = WhisperFeatureExtractor.from_pretrained(feature_extractor_source) self.session = ort.InferenceSession(str(model_path), providers=providers, sess_options=opts) self.input_name = self.session.get_inputs()[0].name self.output_names = [output.name for output in self.session.get_outputs()] self.providers = self.session.get_providers() def load_audio(self, audio_path: Path) -> np.ndarray: audio, _ = librosa.load(str(audio_path), sr=self.sample_rate, mono=True) return audio.astype(np.float32, copy=False) def predict_probabilities(self, audio: np.ndarray) -> np.ndarray: probabilities: list[np.ndarray] = [] for start in range(0, len(audio), self.chunk_samples): chunk = audio[start : start + self.chunk_samples] if len(chunk) < self.chunk_samples: chunk = np.pad(chunk, (0, self.chunk_samples - len(chunk)), mode="constant") features = self.feature_extractor( chunk, sampling_rate=self.sample_rate, return_tensors="np", ).input_features logits = self.session.run(self.output_names, {self.input_name: features})[0][0] probabilities.append(1.0 / (1.0 + np.exp(-logits))) if not probabilities: return np.array([], dtype=np.float32) return np.concatenate(probabilities, axis=0) def load_model( device: str, compute_type: str, device_index: int, ) -> tuple[WhisperModel, str, str, int]: """ Try to load the model on the requested device; if GPU init fails, fall back to CPU. """ try: model = WhisperModel( "model", device=device, compute_type=compute_type, device_index=device_index if device == "cuda" else 0, ) return model, device, compute_type, (device_index if device == "cuda" else 0) except Exception as exc: if device in {"mps", "cuda"}: fallback_device = "cpu" fallback_compute = "int8" print(f"{device.upper()} unavailable, falling back to CPU (reason: {exc})") model = WhisperModel("model", device=fallback_device, compute_type=fallback_compute) return model, fallback_device, fallback_compute, 0 raise def format_timestamp(seconds: float) -> str: total_centiseconds = int(round(seconds * 100)) minutes, remainder = divmod(total_centiseconds, 6000) secs, centiseconds = divmod(remainder, 100) return f"[{minutes:02d}:{secs:02d}.{centiseconds:02d}]" def detect_asmr_speech_segments( audio: np.ndarray, vad_model: AsmrVadModel, threshold: float, neg_threshold: float | None, min_speech_duration_ms: int, min_silence_duration_ms: int, speech_pad_ms: int, ) -> list[dict[str, float]]: speech_probs = vad_model.predict_probabilities(audio) if speech_probs.size == 0: return [] frame_ms = vad_model.frame_duration_ms min_speech_frames = max(1, int(round(min_speech_duration_ms / frame_ms))) min_silence_frames = max(1, int(round(min_silence_duration_ms / frame_ms))) speech_pad_frames = max(0, int(round(speech_pad_ms / frame_ms))) neg_threshold = max(threshold - 0.15, 0.01) if neg_threshold is None else neg_threshold raw_segments: list[tuple[int, int]] = [] triggered = False current_start = 0 temp_end: int | None = None for frame_idx, speech_prob in enumerate(speech_probs): if speech_prob >= threshold and not triggered: triggered = True current_start = frame_idx temp_end = None continue if not triggered: continue if speech_prob < neg_threshold: if temp_end is None: temp_end = frame_idx elif frame_idx - temp_end >= min_silence_frames: if temp_end - current_start >= min_speech_frames: raw_segments.append((current_start, temp_end)) triggered = False temp_end = None elif temp_end is not None: temp_end = None if triggered: end_frame = temp_end if temp_end is not None else len(speech_probs) if end_frame - current_start >= min_speech_frames: raw_segments.append((current_start, end_frame)) segments: list[dict[str, float]] = [] for idx, (start_frame, end_frame) in enumerate(raw_segments): prev_end = raw_segments[idx - 1][1] if idx > 0 else 0 next_start = raw_segments[idx + 1][0] if idx + 1 < len(raw_segments) else len(speech_probs) padded_start = max(prev_end, start_frame - speech_pad_frames) padded_end = min(next_start, end_frame + speech_pad_frames) segments.append( { "start": padded_start * frame_ms / 1000, "end": padded_end * frame_ms / 1000, } ) return segments def build_clip_timestamps(segments: list[dict[str, float]]) -> list[float]: clip_timestamps: list[float] = [] for segment in segments: clip_timestamps.extend([segment["start"], segment["end"]]) return clip_timestamps def write_lrc(segments: Iterable, output_path: Path) -> None: lines = [] for seg in segments: text = seg.text.strip() if not text: continue lines.append(f"{format_timestamp(seg.start)}{text}") if not lines: print(f"Skipping empty transcript for {output_path}") return output_path.write_text("\n".join(lines) + "\n", encoding="utf-8") def transcribe_file( model: WhisperModel, audio_path: Path, beam_size: int, language: str, task: str, vad: bool, vad_parameters: dict | None, clip_timestamps: str | list[float] = "0", extra_generation_args: dict[str, Any] | None = None, ) -> tuple[list, str, float]: segments_iter, info = model.transcribe( str(audio_path), task=task, beam_size=beam_size, vad_filter=vad, vad_parameters=vad_parameters if vad else None, clip_timestamps=clip_timestamps, language=language, **(extra_generation_args or {}), ) print(f"[{audio_path.name}] Detected language: {info.language} (prob={info.language_probability:.2f})") # Stream segments as they arrive so the console updates in real time. segments = [] for seg in segments_iter: segments.append(seg) print(f"[{seg.start:.2f} -> {seg.end:.2f}] {seg.text}", flush=True) return segments, info.language, info.language_probability def collect_audio_files(path: Path) -> list[Path]: if path.is_file(): return [path] if not path.exists(): raise FileNotFoundError(f"Audio path not found: {path}") return sorted(path.rglob("*.mp3")) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run Whisper inference with local CTranslate2 model.") parser.add_argument( "audio", nargs="?", default="mp3", help="Path to an audio file or directory (default: ./mp3).", ) parser.add_argument("--beam-size", type=int, default=5, help="Beam size for decoding.") parser.add_argument( "--language", default="ja", help="Input language code passed to Whisper (default: ja). Use 'auto' for auto-detection.", ) parser.add_argument( "--task", default="translate", choices=["transcribe", "translate"], help="Whisper task mode (default: translate).", ) parser.add_argument( "--vad-mode", default="asmr", choices=["asmr", "builtin", "none"], help="VAD mode: asmr ONNX model, faster-whisper builtin VAD, or none (default: asmr).", ) parser.add_argument("--no-vad", action="store_true", help=argparse.SUPPRESS) parser.add_argument( "--vad-threshold", type=float, default=0.5, help="Speech probability threshold for VAD. Lower is less aggressive (default: 0.5).", ) parser.add_argument( "--vad-neg-threshold", type=float, default=None, help="Optional silence probability threshold for VAD (useful to smooth speech end).", ) parser.add_argument( "--vad-min-speech-ms", type=int, default=300, help="Minimum speech duration (ms) kept by VAD (default: 300).", ) parser.add_argument( "--vad-min-silence-ms", type=int, default=100, help="Minimum silence (ms) to cut a speech chunk when VAD is enabled (default: 100).", ) parser.add_argument( "--vad-pad-ms", type=int, default=200, help="Padding (ms) added before/after each detected speech chunk (default: 200).", ) parser.add_argument( "--vad-model-path", default="vad_models/Whisper-Vad-EncDec-ASMR-onnx/model.onnx", help="Path to the external ASMR VAD ONNX model.", ) parser.add_argument( "--vad-force-cpu", action="store_true", help="Force the external ASMR VAD to run on CPU.", ) parser.add_argument( "--vad-num-threads", type=int, default=1, help="CPU thread count for the external ASMR VAD (default: 1).", ) parser.add_argument( "--max-initial-timestamp", type=float, default=30.0, help="Maximum initial timestamp passed to Whisper decoding (default: 30).", ) parser.add_argument( "--repetition-penalty", type=float, default=1.1, help="Repetition penalty passed to Whisper decoding (default: 1.1).", ) parser.add_argument( "--device", default="cuda", # ✅ 默认用 CUDA choices=["cpu", "mps", "cuda"], help="Inference device. Use 'cuda' on NVIDIA GPUs, 'mps' on Apple Silicon.", ) parser.add_argument( "--device-index", type=int, default=0, help="CUDA device index (e.g., 0 for the first GPU). Ignored for CPU/MPS.", ) parser.add_argument( "--compute-type", default=None, help=( "Override compute type (e.g., int8, int8_float16, float16). " "Default: CUDA/MPS=float16, CPU=int8." ), ) return parser.parse_args() def main() -> None: args = parse_args() if args.no_vad: args.vad_mode = "none" target = Path(args.audio) audio_files = collect_audio_files(target) if not audio_files: print(f"No .mp3 files found under {target}") return compute_type = args.compute_type or ("float16" if args.device in {"cuda", "mps"} else "int8") model, device_used, compute_used, device_index_used = load_model( device=args.device, compute_type=compute_type, device_index=args.device_index, ) if device_used == "cuda": print(f"Using device={device_used}:{device_index_used}, compute_type={compute_used}") else: print(f"Using device={device_used}, compute_type={compute_used}") builtin_vad_parameters = { "threshold": args.vad_threshold, "neg_threshold": args.vad_neg_threshold, "min_speech_duration_ms": args.vad_min_speech_ms, "min_silence_duration_ms": args.vad_min_silence_ms, "speech_pad_ms": args.vad_pad_ms, } asmr_vad: AsmrVadModel | None = None if args.vad_mode == "builtin": print( "Built-in VAD enabled: " f"threshold={builtin_vad_parameters['threshold']}, " f"neg_threshold={builtin_vad_parameters['neg_threshold']}, " f"min_speech_ms={builtin_vad_parameters['min_speech_duration_ms']}, " f"min_silence_ms={builtin_vad_parameters['min_silence_duration_ms']}, " f"pad_ms={builtin_vad_parameters['speech_pad_ms']}" ) elif args.vad_mode == "asmr": vad_model_path = Path(args.vad_model_path) if not vad_model_path.exists(): raise FileNotFoundError(f"ASMR VAD model not found: {vad_model_path}") asmr_vad = AsmrVadModel( model_path=vad_model_path, force_cpu=args.vad_force_cpu, num_threads=args.vad_num_threads, ) print( "ASMR VAD enabled: " f"model={vad_model_path}, " f"providers={asmr_vad.providers}, " f"threshold={args.vad_threshold}, " f"neg_threshold={args.vad_neg_threshold}, " f"min_speech_ms={args.vad_min_speech_ms}, " f"min_silence_ms={args.vad_min_silence_ms}, " f"pad_ms={args.vad_pad_ms}" ) else: print("VAD disabled") generation_args: dict[str, Any] = { "max_initial_timestamp": args.max_initial_timestamp, "repetition_penalty": args.repetition_penalty, } for idx, audio_path in enumerate(audio_files, start=1): print(f"\n[{idx}/{len(audio_files)}] Processing {audio_path}") try: use_builtin_vad = args.vad_mode == "builtin" clip_timestamps: str | list[float] = "0" if args.vad_mode == "asmr" and asmr_vad is not None: audio = asmr_vad.load_audio(audio_path) speech_segments = detect_asmr_speech_segments( audio=audio, vad_model=asmr_vad, threshold=args.vad_threshold, neg_threshold=args.vad_neg_threshold, min_speech_duration_ms=args.vad_min_speech_ms, min_silence_duration_ms=args.vad_min_silence_ms, speech_pad_ms=args.vad_pad_ms, ) if speech_segments: kept_duration = sum(segment["end"] - segment["start"] for segment in speech_segments) print( "ASMR VAD kept " f"{len(speech_segments)} segments " f"({kept_duration:.2f}s speech)" ) clip_timestamps = build_clip_timestamps(speech_segments) else: print("ASMR VAD found no speech segments; falling back to full-audio decoding.") segments, _, _ = transcribe_file( model=model, audio_path=audio_path, beam_size=args.beam_size, language=args.language, task=args.task, vad=use_builtin_vad, vad_parameters=builtin_vad_parameters if use_builtin_vad else None, clip_timestamps=clip_timestamps, extra_generation_args=generation_args, ) write_lrc(segments, audio_path.with_suffix(".lrc")) except Exception as exc: print(f"Failed to process {audio_path}: {exc}") if __name__ == "__main__": main()