Align infer defaults with generation config and force ja translate
This commit is contained in:
parent
8172ca6b9d
commit
26ea9aba51
1 changed files with 5 additions and 7 deletions
12
infer.py
12
infer.py
|
|
@ -56,7 +56,6 @@ def write_lrc(segments: Iterable, output_path: Path) -> None:
|
|||
def transcribe_file(
|
||||
model: WhisperModel,
|
||||
audio_path: Path,
|
||||
language: str | None,
|
||||
beam_size: int,
|
||||
vad: bool,
|
||||
vad_parameters: dict | None,
|
||||
|
|
@ -64,10 +63,11 @@ def transcribe_file(
|
|||
) -> tuple[list, str, float]:
|
||||
segments_iter, info = model.transcribe(
|
||||
str(audio_path),
|
||||
task="translate",
|
||||
beam_size=beam_size,
|
||||
vad_filter=vad,
|
||||
vad_parameters=vad_parameters if vad else None,
|
||||
language=language,
|
||||
language="ja",
|
||||
**(extra_generation_args or {}),
|
||||
)
|
||||
print(f"[{audio_path.name}] Detected language: {info.language} (prob={info.language_probability:.2f})")
|
||||
|
|
@ -97,14 +97,13 @@ def parse_args() -> argparse.Namespace:
|
|||
default="mp3",
|
||||
help="Path to an audio file or directory (default: ./mp3).",
|
||||
)
|
||||
parser.add_argument("--language", help="Language code (e.g., zh, en). Default: ja.", default="ja")
|
||||
parser.add_argument("--beam-size", type=int, default=5, help="Beam size for decoding.")
|
||||
parser.add_argument("--no-vad", action="store_true", help="Disable VAD (voice activity detection) filtering.")
|
||||
parser.add_argument(
|
||||
"--vad-threshold",
|
||||
type=float,
|
||||
default=0.35,
|
||||
help="Speech probability threshold for VAD. Lower is less aggressive (default: 0.35).",
|
||||
default=0.5,
|
||||
help="Speech probability threshold for VAD. Lower is less aggressive (default: 0.5).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vad-neg-threshold",
|
||||
|
|
@ -185,7 +184,7 @@ def main() -> None:
|
|||
)
|
||||
|
||||
generation_args: dict[str, Any] = {
|
||||
# "max_initial_timestamp": 10,
|
||||
"max_initial_timestamp": 30,
|
||||
"repetition_penalty": 1.1,
|
||||
}
|
||||
|
||||
|
|
@ -195,7 +194,6 @@ def main() -> None:
|
|||
segments, _, _ = transcribe_file(
|
||||
model=model,
|
||||
audio_path=audio_path,
|
||||
language=args.language,
|
||||
beam_size=args.beam_size,
|
||||
vad=not args.no_vad,
|
||||
vad_parameters=vad_parameters,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue