mammoth-shoe-93270
07/27/2025, 1:54 AM
tall-airline-16756
07/27/2025, 1:54 AM
tall-airline-16756
07/27/2025, 1:55 AM

from __future__ import annotations
import asyncio
import os
from dataclasses import dataclass

import aiohttp

from livekit import rtc
from livekit.agents import (
    APIConnectionError,
    APIConnectOptions,
    APIStatusError,
    APITimeoutError,
    stt,
)
from livekit.agents.stt import SpeechEventType, STTCapabilities
from livekit.agents.types import NOT_GIVEN, NotGivenOr
from livekit.agents.utils import AudioBuffer, is_given


@dataclass
class _STTOptions:
    language: str = "en"
    api_key: str = ""
    base_url: str = ""


class CustomSTT(stt.STT):
    def __init__(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        api_key: NotGivenOr[str] = NOT_GIVEN,
        base_url: NotGivenOr[str] = NOT_GIVEN,
        http_session: aiohttp.ClientSession | None = None,
    ):
        # Set streaming=False since this provider doesn't support streaming
        super().__init__(capabilities=STTCapabilities(streaming=False, interim_results=True))

        self._api_key = api_key if is_given(api_key) else os.getenv("CUSTOM_STT_API_KEY")
        if not self._api_key:
            raise ValueError("API key is required. Set CUSTOM_STT_API_KEY environment variable")

        self._opts = _STTOptions(
            language=language if is_given(language) else "en",
            api_key=self._api_key,
            base_url=base_url if is_given(base_url) else "<YOUR_API_BASE_URL>",
        )
        self._session = http_session

    def update_options(self, *, language: NotGivenOr[str] = NOT_GIVEN) -> None:
        if is_given(language):
            self._opts.language = language

    def _ensure_session(self) -> aiohttp.ClientSession:
        if not self._session:
            from livekit.agents import utils

            self._session = utils.http_context.http_session()
        return self._session

    async def _recognize_impl(
        self,
        buffer: AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        try:
            if is_given(language):
                self._opts.language = language

            # Convert audio buffer to WAV bytes
            wav_bytes = rtc.combine_audio_frames(buffer).to_wav_bytes()

            # Prepare form data for the API request
            form_data = aiohttp.FormData()
            form_data.add_field("audio", wav_bytes, filename="audio.wav", content_type="audio/wav")
            form_data.add_field("language", self._opts.language)

            headers = {
                "Authorization": f"Bearer {self._opts.api_key}",
                "User-Agent": "LiveKit-Agents",
            }

            # Make the API request
            async with self._ensure_session().post(
                self._opts.base_url,
                data=form_data,
                headers=headers,
                timeout=aiohttp.ClientTimeout(
                    total=30,
                    sock_connect=conn_options.timeout,
                ),
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise APIStatusError(
                        message=f"API Error: {error_text}",
                        status_code=response.status,
                    )

                response_data = await response.json()
                text = response_data.get("transcript", "")
                confidence = response_data.get("confidence", 1.0)

                return self._transcription_to_speech_event(text=text, confidence=confidence)

        except asyncio.TimeoutError as e:
            raise APITimeoutError() from e
        except aiohttp.ClientResponseError as e:
            raise APIStatusError(
                message=e.message,
                status_code=e.status,
                request_id=None,
                body=None,
            ) from e
        except APIStatusError:
            # Re-raise status errors from the non-200 branch unchanged
            raise
        except Exception as e:
            raise APIConnectionError() from e

    def _transcription_to_speech_event(self, text: str, confidence: float = 1.0) -> stt.SpeechEvent:
        return stt.SpeechEvent(
            type=SpeechEventType.FINAL_TRANSCRIPT,
            alternatives=[
                stt.SpeechData(text=text, language=self._opts.language, confidence=confidence)
            ],
        )

    async def aclose(self) -> None:
        if self._session:
            await self._session.close()
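If you want to sanity-check the provider outside of an agent session, you can call the public recognize() method that CustomSTT inherits from stt.STT (it wraps _recognize_impl with retry handling). This is only a minimal sketch, not from the sources above: the second of silence used as the audio buffer and the explicitly created aiohttp session are placeholders.

import asyncio

import aiohttp
from livekit import rtc


async def main():
    async with aiohttp.ClientSession() as session:
        # Reads CUSTOM_STT_API_KEY from the environment
        custom_stt = CustomSTT(http_session=session)

        # One second of 16 kHz mono silence as a stand-in for real captured audio
        frame = rtc.AudioFrame(
            data=b"\x00\x00" * 16000,
            sample_rate=16000,
            num_channels=1,
            samples_per_channel=16000,
        )
        event = await custom_stt.recognize(buffer=[frame])
        print(event.alternatives[0].text)


asyncio.run(main())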
Using with VAD and StreamAdapter
Since your STT provider doesn’t support streaming, you’ll need to use VAD (Voice Activity Detection) with StreamAdapter to buffer audio until speech segments are complete:
import asyncio

from livekit import agents, rtc
from livekit.plugins import silero


async def process_track(ctx: agents.JobContext, track: rtc.Track):
    # Create your custom non-streaming STT
    custom_stt = CustomSTT()

    # Set up VAD to detect speech segments
    vad = silero.VAD.load(
        min_speech_duration=0.1,
        min_silence_duration=0.5,
    )
    vad_stream = vad.stream()

    # StreamAdapter will buffer audio until VAD emits an end-of-speech event
    stt = agents.stt.StreamAdapter(custom_stt, vad_stream)
    stt_stream = stt.stream()

    audio_stream = rtc.AudioStream(track)

    # Process transcription results concurrently
    transcription_task = asyncio.create_task(process_transcription(stt_stream))

    # Feed audio frames to the STT stream
    async for audio_event in audio_stream:
        stt_stream.push_frame(audio_event.frame)

    stt_stream.end_input()
    await transcription_task


async def process_transcription(stream):
    async for event in stream:
        if event.type == agents.stt.SpeechEventType.FINAL_TRANSCRIPT:
            text = event.alternatives[0].text
            print(f"Transcribed: {text}")
Key points to note:
• Set streaming=False in STTCapabilities.
• Implement the _recognize_impl method to process complete audio buffers.
• Convert audio to the appropriate format (e.g., WAV bytes).
• Handle API errors appropriately.
• Use VAD and StreamAdapter to enable streaming-like behavior for non-streaming providers.
This approach ensures that your custom STT module can handle non-streaming APIs effectively while simulating streaming behavior.
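For completeness, here is a rough sketch of wiring CustomSTT into a voice agent entrypoint. It is not taken from the sources above and assumes the LiveKit Agents 1.x AgentSession API, with the OpenAI plugin as a placeholder LLM/TTS choice; supplying a VAD lets the session drive the non-streaming STT with the same buffer-until-end-of-speech pattern as the StreamAdapter example.

from livekit import agents
from livekit.agents import Agent, AgentSession
from livekit.plugins import openai, silero


async def entrypoint(ctx: agents.JobContext):
    await ctx.connect()

    session = AgentSession(
        vad=silero.VAD.load(),                 # needed to segment speech for the non-streaming STT
        stt=CustomSTT(language="en"),
        llm=openai.LLM(model="gpt-4o-mini"),   # placeholder model choice
        tts=openai.TTS(),                      # placeholder TTS choice
    )
    await session.start(
        room=ctx.room,
        agent=Agent(instructions="You are a helpful voice assistant."),
    )


if __name__ == "__main__":
    agents.cli.run_app(agents.WorkerOptions(entrypoint_fnc=entrypoint))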
Sources: livekit-plugins/livekit-plugins-sarvam/livekit/plugins/sarvam/stt.py | Speech-to-text (STT) integrations | LiveKit Docs | Working with plugins | LiveKit Docs