Custom Strategies
Learn how to create custom strategies for any LLM provider.
Basic Custom Strategy
from async_batch_llm import LLMCallStrategy
class OpenAIStrategy(LLMCallStrategy[str]):
    """Strategy that sends a single user message to the OpenAI chat API.

    Returns the completion text together with a token-usage dictionary.
    """

    def __init__(self, client, model: str):
        self.client = client
        self.model = model

    async def execute(self, prompt: str, attempt: int, timeout: float):
        """Run one chat completion and return ``(text, token_counts)``."""
        response = await self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
        )
        usage = response.usage
        token_counts = {
            "input_tokens": usage.prompt_tokens,
            "output_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
        }
        return response.choices[0].message.content, token_counts
Resource Management
Use prepare() and cleanup() to manage the lifecycle of resources needed across calls — for example, creating a provider-side cache before processing and deleting it afterwards:
class CachedStrategy(LLMCallStrategy[str]):
    """Strategy that stores a system instruction in a provider-side cache.

    prepare() creates the cache, execute() generates against it, and
    cleanup() deletes it when processing is done.
    """

    def __init__(self, client, system_instruction: str):
        self.client = client
        self.system_instruction = system_instruction
        # Filled in by prepare(); stays None until a cache exists.
        self.cache_name = None

    async def prepare(self):
        """Create cache before processing."""
        self.cache_name = await self.client.create_cache(
            content=self.system_instruction
        )

    async def execute(self, prompt: str, attempt: int, timeout: float):
        # Generate against the cache created in prepare().
        result = await self.client.generate(
            prompt=prompt,
            cache_name=self.cache_name,
        )
        return result.text, result.usage

    async def cleanup(self):
        """Delete cache after processing."""
        if self.cache_name:
            await self.client.delete_cache(self.cache_name)
Error Handling
Use on_error() to track failures and adjust behavior:
from pydantic import ValidationError
class SmartRetryStrategy(LLMCallStrategy[dict]):
    """Strategy that escalates to stronger models as validation failures mount.

    Starts on the cheapest model and only moves up a tier after each
    observed ``ValidationError``.
    """

    # Model tiers indexed by the number of validation failures seen so far.
    _MODEL_TIERS = ("cheap-model", "medium-model", "expensive-model")

    def __init__(self, client):
        self.client = client
        self.validation_failures = 0

    async def on_error(self, exception: Exception, attempt: int):
        """Count validation errors so execute() can escalate the model."""
        if isinstance(exception, ValidationError):
            self.validation_failures += 1

    async def execute(self, prompt: str, attempt: int, timeout: float):
        # 0 failures -> cheap, 1 -> medium, 2+ -> expensive (clamped to last tier).
        tier = min(self.validation_failures, len(self._MODEL_TIERS) - 1)
        response = await self.client.generate(
            prompt, model=self._MODEL_TIERS[tier]
        )
        return response.output, response.tokens
Progressive Temperature
Raise the sampling temperature on each retry attempt, so early attempts are deterministic and later ones explore more varied outputs:
class ProgressiveTempStrategy(LLMCallStrategy[str]):
    """Strategy that steps through a temperature ladder across retries."""

    def __init__(self, client, temperatures=None):
        self.client = client
        # Default ladder: deterministic first, progressively more creative.
        self.temperatures = temperatures or [0.0, 0.5, 1.0]

    async def execute(self, prompt: str, attempt: int, timeout: float):
        # `attempt` is 1-based; clamp so late retries reuse the last rung.
        ladder = self.temperatures
        rung = min(attempt - 1, len(ladder) - 1)
        response = await self.client.generate(
            prompt=prompt,
            temperature=ladder[rung],
        )
        return response.text, response.usage
Anthropic Example
from anthropic import AsyncAnthropic
class AnthropicStrategy(LLMCallStrategy[str]):
    """Strategy targeting the Anthropic Messages API."""

    def __init__(self, client: AsyncAnthropic, model: str = "claude-3-5-sonnet-20241022"):
        self.client = client
        self.model = model

    async def execute(self, prompt: str, attempt: int, timeout: float):
        """Send one user message and return ``(text, token_counts)``."""
        message = await self.client.messages.create(
            model=self.model,
            max_tokens=1024,
            messages=[{"role": "user", "content": prompt}],
        )
        usage = message.usage
        # Anthropic reports input/output separately; total is derived here.
        return message.content[0].text, {
            "input_tokens": usage.input_tokens,
            "output_tokens": usage.output_tokens,
            "total_tokens": usage.input_tokens + usage.output_tokens,
        }
Usage
from async_batch_llm import ParallelBatchProcessor, LLMWorkItem, ProcessorConfig
async def main():
    """Run one work item through the processor using a custom strategy."""
    strategy = OpenAIStrategy(client=openai_client, model="gpt-4")
    config = ProcessorConfig(max_workers=5)

    async with ParallelBatchProcessor(config=config) as processor:
        work_item = LLMWorkItem(
            item_id="test",
            strategy=strategy,
            prompt="Hello!",
        )
        await processor.add_work(work_item)
        result = await processor.process_all()