Skip to content

Advanced Patterns

Smart Model Escalation

Save costs by starting with cheap models and escalating only on validation errors:

from pydantic import ValidationError
from async_batch_llm import LLMCallStrategy

class SmartModelEscalation(LLMCallStrategy[dict]):
    """Call strategy that steps up to a stronger model after validation failures.

    The model is chosen from ``MODELS`` using the number of
    ``ValidationError`` exceptions reported to ``on_error`` so far. Other
    exception types do not change the selected model.
    """

    # Ordered from cheapest to most capable; escalation walks left to right.
    MODELS = [
        "gemini-2.5-flash-lite",  # Cheapest
        "gemini-2.5-flash",       # Medium
        "gemini-2.5-pro",         # Most capable
    ]

    def __init__(self, client):
        """Store the client and start with zero recorded validation failures."""
        self.validation_failures = 0
        self.client = client

    async def on_error(self, exception: Exception, attempt: int):
        """Count validation errors only; network/rate limit errors are ignored."""
        if not isinstance(exception, ValidationError):
            return
        self.validation_failures += 1

    async def execute(self, prompt: str, attempt: int, timeout: float):
        """Generate with the model tier matching the current failure count.

        Returns:
            A ``(response.output, response.tokens)`` tuple from the client.
        """
        # A retry caused by a non-validation error keeps the same model,
        # because only validation failures advance the tier.
        top_tier = len(self.MODELS) - 1
        if self.validation_failures < top_tier:
            tier = self.validation_failures
        else:
            # Clamp: once the list is exhausted, stay on the last model.
            tier = top_tier

        response = await self.client.generate(prompt, model=self.MODELS[tier])
        return response.output, response.tokens

Cost savings: 60-80% vs. always using the best model.

Smart Retry with Validation Feedback

Tell the LLM exactly what failed on retry:

class SmartRetryStrategy(LLMCallStrategy[PersonData]):
    """Call strategy that tells the model which fields failed on the last attempt.

    ``on_error`` remembers the most recent ``ValidationError``; on later
    attempts ``execute`` sends a prompt that appends field-level feedback
    built from that error.
    """

    def __init__(self, client):
        """Store the client; no error or failed response is recorded yet."""
        self.client = client
        self.last_error = None     # most recent ValidationError passed to on_error
        self.last_response = None  # raw text of the last response that failed validation

    async def on_error(self, exception: Exception, attempt: int):
        """Remember validation errors; other exception types are ignored."""
        if isinstance(exception, ValidationError):
            self.last_error = exception

    async def execute(self, prompt: str, attempt: int, timeout: float):
        """Generate a response and validate it as ``PersonData``.

        Returns:
            A ``(PersonData, tokens)`` tuple.

        Raises:
            ValidationError: If the response text does not validate; the raw
                text is kept in ``self.last_response`` before re-raising.
        """
        if attempt == 1:
            final_prompt = prompt
        else:
            # Create retry prompt with field-level feedback
            final_prompt = self._create_retry_prompt(prompt)

        # Kept outside the try block so the handler below never touches an
        # unbound ``response`` when the call itself fails.
        response = await self.client.generate(final_prompt)
        try:
            output = PersonData.model_validate_json(response.text)
        except ValidationError:
            self.last_response = response.text
            raise
        # NOTE(review): assumes the client response exposes ``tokens``, as the
        # SmartModelEscalation example does - confirm against your client.
        return output, response.tokens

    def _create_retry_prompt(self, original_prompt: str) -> str:
        """Return ``original_prompt`` extended with feedback from ``self.last_error``.

        If no validation error has been recorded (for example, the retry was
        triggered by a network error), the original prompt is returned as is.
        """
        if self.last_error is None:
            return original_prompt

        failed = []
        details = []
        whole_document_invalid = False
        for err in self.last_error.errors():
            location = err["loc"]
            if not location:
                # An empty location means the payload as a whole was rejected
                # (e.g. invalid JSON), so no field can be reported as valid.
                whole_document_invalid = True
            field = ".".join(str(part) for part in location) or "<root>"
            if field not in failed:
                failed.append(field)
            details.append(f"- {field}: {err['msg']}")

        succeeded = []
        if not whole_document_invalid:
            failed_top_level = {name.split(".", 1)[0] for name in failed}
            succeeded = [
                name for name in PersonData.model_fields
                if name not in failed_top_level
            ]

        lines = [original_prompt, "", "Your previous response failed validation."]
        if self.last_response is not None:
            lines.append(f"Previous response: {self.last_response}")
        if succeeded:
            lines.append(f"These fields succeeded: [{', '.join(succeeded)}].")
        lines.append(f"Fix these: [{', '.join(failed)}]")
        lines.extend(details)
        return "\n".join(lines)

Shared Context Caching

Dramatically reduce costs for RAG and other workloads that send the same large context with every request:

from async_batch_llm import GeminiCachedModel, GeminiStrategy
from google import genai
from google.genai.types import Content

async def process_with_caching():
    """Run 100 queries against one shared, cached RAG context.

    Reads ``knowledge_base.txt``, wraps it in a ``GeminiCachedModel``, and
    queues one work item per question on a ``ParallelBatchProcessor``.

    Returns:
        The value returned by ``processor.process_all()``.
    """
    client = genai.Client(api_key="your-key")

    # Load large RAG context once. The encoding is explicit so the file is
    # decoded the same way regardless of the platform's locale default.
    with open("knowledge_base.txt", encoding="utf-8") as f:
        rag_context = f.read()  # Could be 100K+ tokens

    # Model manages cache lifecycle (prepare/cleanup)
    cached_model = GeminiCachedModel(
        "gemini-2.5-flash", client,
        cached_content=[Content(parts=[{"text": rag_context}], role="user")],
    )
    strategy = GeminiStrategy(cached_model, response_parser=lambda r: r.text)

    config = ProcessorConfig(max_workers=5)

    async with ParallelBatchProcessor(config=config) as processor:
        # All 100 queries share the same cached context.
        # NOTE(review): ``questions`` is not defined in this snippet; it must
        # be a sequence of at least 100 question strings in the enclosing scope.
        for i in range(100):
            await processor.add_work(
                LLMWorkItem(
                    item_id=f"query_{i}",
                    strategy=strategy,
                    prompt=f"Answer based on context: {questions[i]}"
                )
            )

        result = await processor.process_all()
        # Cache automatically cleaned up on exit

    # Hand the batch result back to the caller instead of discarding it.
    return result

Cost savings: ~90% for input tokens on cached content.

Middleware for Custom Logic

Inject custom behavior into the processing pipeline:

from async_batch_llm.middleware import Middleware
from async_batch_llm import LLMWorkItem, WorkItemResult

class LoggingMiddleware(Middleware):
    """Middleware that prints one line from each processing hook."""

    async def before_process(self, work_item: LLMWorkItem):
        """Print the id of the work item passed to this hook."""
        print(f"Starting {work_item.item_id}")

    async def after_process(self, result: WorkItemResult):
        """Print the outcome of a result, including the error on failure."""
        if not result.success:
            print(f"Failed: {result.item_id} - {result.error}")
            return
        print(f"Success: {result.item_id}")

    async def on_retry(self, work_item: LLMWorkItem, attempt: int, error: Exception):
        """Print the attempt number, item id, and error passed to this hook."""
        print(f"Retry {attempt} for {work_item.item_id}: {error}")

async def main():
    """Process a batch with ``LoggingMiddleware`` attached to the processor."""
    logging_middleware = LoggingMiddleware()
    # ``config`` is expected to be defined by the surrounding code.
    logged_processor = ParallelBatchProcessor(
        config=config,
        middlewares=[logging_middleware],
    )

    async with logged_processor as processor:
        # Add work items...
        result = await processor.process_all()

Custom Observers

Track custom metrics:

from async_batch_llm.observers import BaseObserver, ProcessingEvent
from async_batch_llm import LLMWorkItem, WorkItemResult
from typing import Any

class CostTracker(BaseObserver):
    """Observer that accumulates token usage and an estimated cost.

    Args:
        cost_per_token: Price applied to every counted token. The default,
            ``0.00001``, is only an example rate.
    """

    def __init__(self, cost_per_token: float = 0.00001):
        self.cost_per_token = cost_per_token
        self.total_cost = 0.0
        self.total_tokens = 0

    async def on_event(self, event: ProcessingEvent, data: dict[str, Any]) -> None:
        """Add the tokens reported by an ``ITEM_COMPLETED`` event to the totals.

        All other events are ignored.
        """
        if event != ProcessingEvent.ITEM_COMPLETED:
            return

        # ``or`` fallbacks cover both a missing key and an explicit None, so
        # an item without usage data counts as zero tokens instead of raising.
        tokens = data.get("tokens") or {}
        total = tokens.get("total_tokens", 0) or 0
        self.total_tokens += total
        self.total_cost += total * self.cost_per_token

async def main():
    """Process a batch with a ``CostTracker`` observer and print its totals."""
    cost_tracker = CostTracker()
    # ``config`` is expected to be defined by the surrounding code.
    tracked_processor = ParallelBatchProcessor(
        config=config,
        observers=[cost_tracker],
    )

    async with tracked_processor as processor:
        # Add work items...
        result = await processor.process_all()

        # Report what the observer accumulated during the run.
        print(f"Total tokens: {cost_tracker.total_tokens}")
        print(f"Estimated cost: ${cost_tracker.total_cost:.4f}")

Dynamic Worker Scaling

Adjust the number of workers for the next batch based on how often the previous batch was rate limited:

async def adaptive_processing():
    """Process a batch, then lower ``max_workers`` if it was rate limited often.

    Starts with 10 workers and a 30-second per-item timeout. After the batch,
    a ``rate_limit_count`` above 5 in the processor stats sets
    ``processor.config.max_workers`` to 3.
    """
    config = ProcessorConfig(
        timeout_per_item=30.0,
        max_workers=10,  # Start optimistic
    )

    async with ParallelBatchProcessor(config=config) as processor:
        # Add work...
        result = await processor.process_all()

        # Check if rate limited
        stats = await processor.get_stats()
        rate_limited_too_often = stats["rate_limit_count"] > 5
        if rate_limited_too_often:
            # Too many rate limits, reduce workers for next batch
            processor.config.max_workers = 3