# TechScout/techscout/extraction/llm_client.py
"""
Ollama LLM Client for TechScout.
Handles all LLM interactions for query decomposition, analysis, and scoring.
"""
import json
import logging
import re
import requests
from dataclasses import dataclass
from typing import Optional, Dict, Any, List
logger = logging.getLogger(__name__)
@dataclass
class LLMResponse:
    """Normalized result of a single LLM call."""
    # Raw text returned by the model (empty string on failure).
    content: str
    # True when the call completed without a transport or decode error.
    success: bool
    # Human-readable error description; None when success is True.
    error: Optional[str] = None
    # Name of the model that served the request.
    model: str = ""
    # Token count reported by the backend (0 when unknown).
    tokens_used: int = 0
class OllamaClient:
    """
    Client for a local Ollama LLM server.

    Wraps the Ollama HTTP API (``/api/generate`` and ``/api/tags``) and
    normalizes every outcome of ``generate`` into an ``LLMResponse`` so
    callers never need to handle ``requests`` exceptions themselves.
    """

    # Regex patterns tried in order when pulling JSON out of free text:
    # fenced ```json block, any fenced block, then a bare {...} span.
    _JSON_PATTERNS = [
        r'```json\s*([\s\S]*?)\s*```',
        r'```\s*([\s\S]*?)\s*```',
        r'\{[\s\S]*\}',
    ]

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        default_model: str = "mistral-nemo:12b",
        timeout: float = 120.0,
    ):
        """
        Args:
            base_url: Base URL of the Ollama server; a trailing slash is stripped.
            default_model: Model used by generate() when none is given.
            timeout: Per-request timeout in seconds for generate() calls
                (previously hard-coded to 120).
        """
        self.base_url = base_url.rstrip("/")
        self.default_model = default_model
        self.timeout = timeout

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,
        system: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: int = 4096,
        format: Optional[str] = None  # "json" for JSON mode; name kept for API compat although it shadows the builtin
    ) -> "LLMResponse":
        """
        Generate a response from Ollama.

        Args:
            prompt: The user prompt.
            model: Model to use (defaults to self.default_model).
            system: Optional system prompt.
            temperature: Sampling temperature.
            max_tokens: Max tokens to generate.
            format: Response format ("json" enables Ollama's JSON mode).

        Returns:
            LLMResponse. On any transport or decode failure success is
            False and error describes the problem — this method never raises.
        """
        model = model or self.default_model
        payload: Dict[str, Any] = {
            "model": model,
            "prompt": prompt,
            "stream": False,  # request one complete response, not a stream
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens,
            },
        }
        if system:
            payload["system"] = system
        if format == "json":
            payload["format"] = "json"
        try:
            response = requests.post(
                f"{self.base_url}/api/generate",
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()
        except requests.exceptions.Timeout:
            # Timeout must be caught before RequestException (its superclass).
            return LLMResponse(
                content="",
                success=False,
                error="Request timed out",
                model=model,
            )
        except requests.exceptions.RequestException as e:
            return LLMResponse(
                content="",
                success=False,
                error=str(e),
                model=model,
            )
        except ValueError as e:
            # Bug fix: the server can answer 2xx with a non-JSON body, in
            # which case response.json() raises; previously this escaped
            # uncaught to the caller.
            return LLMResponse(
                content="",
                success=False,
                error=f"Invalid JSON in response: {e}",
                model=model,
            )
        return LLMResponse(
            content=data.get("response", ""),
            success=True,
            model=model,
            tokens_used=data.get("eval_count", 0),
        )

    def extract_json_from_text(self, text: str) -> Optional[Dict[str, Any]]:
        """
        Best-effort extraction of JSON from text that may contain other content.

        Tries fenced code blocks first, then a bare {...} span, then the
        whole text. Returns the first candidate that parses, else None.
        """
        for pattern in self._JSON_PATTERNS:
            for candidate in re.findall(pattern, text):
                try:
                    return json.loads(candidate)
                except json.JSONDecodeError:
                    continue
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return None

    def is_available(self) -> bool:
        """Check if Ollama is running and accessible (True iff /api/tags returns 200)."""
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            return response.status_code == 200
        except Exception:
            # Any transport error simply means "not available".
            return False

    def list_models(self) -> List[str]:
        """List the names of locally available models ([] on any failure)."""
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=10)
            response.raise_for_status()
            data = response.json()
            return [m["name"] for m in data.get("models", [])]
        except Exception as e:
            # Lazy %-args: message is only formatted when actually emitted.
            logger.error("Failed to list models: %s", e)
            return []