Align inference defaults with the generation config and force Japanese (ja) input with the translate task

This commit is contained in:
xuwei 2026-03-03 20:50:11 +08:00
parent 8172ca6b9d
commit 26ea9aba51

View file

@ -56,7 +56,6 @@ def write_lrc(segments: Iterable, output_path: Path) -> None:
def transcribe_file( def transcribe_file(
model: WhisperModel, model: WhisperModel,
audio_path: Path, audio_path: Path,
language: str | None,
beam_size: int, beam_size: int,
vad: bool, vad: bool,
vad_parameters: dict | None, vad_parameters: dict | None,
@ -64,10 +63,11 @@ def transcribe_file(
) -> tuple[list, str, float]: ) -> tuple[list, str, float]:
segments_iter, info = model.transcribe( segments_iter, info = model.transcribe(
str(audio_path), str(audio_path),
task="translate",
beam_size=beam_size, beam_size=beam_size,
vad_filter=vad, vad_filter=vad,
vad_parameters=vad_parameters if vad else None, vad_parameters=vad_parameters if vad else None,
language=language, language="ja",
**(extra_generation_args or {}), **(extra_generation_args or {}),
) )
print(f"[{audio_path.name}] Detected language: {info.language} (prob={info.language_probability:.2f})") print(f"[{audio_path.name}] Detected language: {info.language} (prob={info.language_probability:.2f})")
@ -97,14 +97,13 @@ def parse_args() -> argparse.Namespace:
default="mp3", default="mp3",
help="Path to an audio file or directory (default: ./mp3).", help="Path to an audio file or directory (default: ./mp3).",
) )
parser.add_argument("--language", help="Language code (e.g., zh, en). Default: ja.", default="ja")
parser.add_argument("--beam-size", type=int, default=5, help="Beam size for decoding.") parser.add_argument("--beam-size", type=int, default=5, help="Beam size for decoding.")
parser.add_argument("--no-vad", action="store_true", help="Disable VAD (voice activity detection) filtering.") parser.add_argument("--no-vad", action="store_true", help="Disable VAD (voice activity detection) filtering.")
parser.add_argument( parser.add_argument(
"--vad-threshold", "--vad-threshold",
type=float, type=float,
default=0.35, default=0.5,
help="Speech probability threshold for VAD. Lower is less aggressive (default: 0.35).", help="Speech probability threshold for VAD. Lower is less aggressive (default: 0.5).",
) )
parser.add_argument( parser.add_argument(
"--vad-neg-threshold", "--vad-neg-threshold",
@ -185,7 +184,7 @@ def main() -> None:
) )
generation_args: dict[str, Any] = { generation_args: dict[str, Any] = {
# "max_initial_timestamp": 10, "max_initial_timestamp": 30,
"repetition_penalty": 1.1, "repetition_penalty": 1.1,
} }
@ -195,7 +194,6 @@ def main() -> None:
segments, _, _ = transcribe_file( segments, _, _ = transcribe_file(
model=model, model=model,
audio_path=audio_path, audio_path=audio_path,
language=args.language,
beam_size=args.beam_size, beam_size=args.beam_size,
vad=not args.no_vad, vad=not args.no_vad,
vad_parameters=vad_parameters, vad_parameters=vad_parameters,