Align infer defaults with generation config and force ja translate
This commit is contained in:
parent
8172ca6b9d
commit
26ea9aba51
1 changed file with 5 additions and 7 deletions
12
infer.py
12
infer.py
|
|
@@ -56,7 +56,6 @@ def write_lrc(segments: Iterable, output_path: Path) -> None:
|
||||||
def transcribe_file(
|
def transcribe_file(
|
||||||
model: WhisperModel,
|
model: WhisperModel,
|
||||||
audio_path: Path,
|
audio_path: Path,
|
||||||
language: str | None,
|
|
||||||
beam_size: int,
|
beam_size: int,
|
||||||
vad: bool,
|
vad: bool,
|
||||||
vad_parameters: dict | None,
|
vad_parameters: dict | None,
|
||||||
|
|
@@ -64,10 +63,11 @@ def transcribe_file(
|
||||||
) -> tuple[list, str, float]:
|
) -> tuple[list, str, float]:
|
||||||
segments_iter, info = model.transcribe(
|
segments_iter, info = model.transcribe(
|
||||||
str(audio_path),
|
str(audio_path),
|
||||||
|
task="translate",
|
||||||
beam_size=beam_size,
|
beam_size=beam_size,
|
||||||
vad_filter=vad,
|
vad_filter=vad,
|
||||||
vad_parameters=vad_parameters if vad else None,
|
vad_parameters=vad_parameters if vad else None,
|
||||||
language=language,
|
language="ja",
|
||||||
**(extra_generation_args or {}),
|
**(extra_generation_args or {}),
|
||||||
)
|
)
|
||||||
print(f"[{audio_path.name}] Detected language: {info.language} (prob={info.language_probability:.2f})")
|
print(f"[{audio_path.name}] Detected language: {info.language} (prob={info.language_probability:.2f})")
|
||||||
|
|
@@ -97,14 +97,13 @@ def parse_args() -> argparse.Namespace:
|
||||||
default="mp3",
|
default="mp3",
|
||||||
help="Path to an audio file or directory (default: ./mp3).",
|
help="Path to an audio file or directory (default: ./mp3).",
|
||||||
)
|
)
|
||||||
parser.add_argument("--language", help="Language code (e.g., zh, en). Default: ja.", default="ja")
|
|
||||||
parser.add_argument("--beam-size", type=int, default=5, help="Beam size for decoding.")
|
parser.add_argument("--beam-size", type=int, default=5, help="Beam size for decoding.")
|
||||||
parser.add_argument("--no-vad", action="store_true", help="Disable VAD (voice activity detection) filtering.")
|
parser.add_argument("--no-vad", action="store_true", help="Disable VAD (voice activity detection) filtering.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vad-threshold",
|
"--vad-threshold",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.35,
|
default=0.5,
|
||||||
help="Speech probability threshold for VAD. Lower is less aggressive (default: 0.35).",
|
help="Speech probability threshold for VAD. Lower is less aggressive (default: 0.5).",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vad-neg-threshold",
|
"--vad-neg-threshold",
|
||||||
|
|
@@ -185,7 +184,7 @@ def main() -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
generation_args: dict[str, Any] = {
|
generation_args: dict[str, Any] = {
|
||||||
# "max_initial_timestamp": 10,
|
"max_initial_timestamp": 30,
|
||||||
"repetition_penalty": 1.1,
|
"repetition_penalty": 1.1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@@ -195,7 +194,6 @@ def main() -> None:
|
||||||
segments, _, _ = transcribe_file(
|
segments, _, _ = transcribe_file(
|
||||||
model=model,
|
model=model,
|
||||||
audio_path=audio_path,
|
audio_path=audio_path,
|
||||||
language=args.language,
|
|
||||||
beam_size=args.beam_size,
|
beam_size=args.beam_size,
|
||||||
vad=not args.no_vad,
|
vad=not args.no_vad,
|
||||||
vad_parameters=vad_parameters,
|
vad_parameters=vad_parameters,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue