"""
|
|
Ollama LLM Client for TechScout.
|
|
Handles all LLM interactions for query decomposition, analysis, and scoring.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
import requests
|
|
from dataclasses import dataclass
|
|
from typing import Optional, Dict, Any, List
|
|
|
|
logger = logging.getLogger(__name__)


@dataclass
class LLMResponse:
    """Response from LLM."""

    content: str
    success: bool
    error: Optional[str] = None
    model: str = ""
    tokens_used: int = 0


class OllamaClient:
    """Client for Ollama local LLM."""

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        default_model: str = "mistral-nemo:12b"
    ):
        self.base_url = base_url.rstrip("/")
        self.default_model = default_model

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,
        system: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: int = 4096,
        format: Optional[str] = None  # "json" for JSON mode
    ) -> LLMResponse:
        """
        Generate a response from Ollama.

        Args:
            prompt: The user prompt
            model: Model to use (defaults to self.default_model)
            system: System prompt
            temperature: Sampling temperature
            max_tokens: Max tokens to generate
            format: Response format ("json" for JSON mode)

        Returns:
            LLMResponse object
        """
        model = model or self.default_model

        payload = {
            "model": model,
            "prompt": prompt,
            "stream": False,  # return the full completion in a single response
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens,  # Ollama's max-tokens option
            }
        }

        if system:
            payload["system"] = system

        if format == "json":
            payload["format"] = "json"

        try:
            response = requests.post(
                f"{self.base_url}/api/generate",
                json=payload,
                timeout=120
            )
            response.raise_for_status()
            data = response.json()

            return LLMResponse(
                content=data.get("response", ""),
                success=True,
                model=model,
                tokens_used=data.get("eval_count", 0)
            )

        except requests.exceptions.Timeout:
            return LLMResponse(
                content="",
                success=False,
                error="Request timed out",
                model=model
            )
        except requests.exceptions.RequestException as e:
            return LLMResponse(
                content="",
                success=False,
                error=str(e),
                model=model
            )

    def extract_json_from_text(self, text: str) -> Optional[Dict[str, Any]]:
        """Extract JSON from text that might contain other content."""
        # Try to find a JSON block: a ```json fence first, then a generic
        # fence, then a greedy brace match (first "{" to last "}") as a fallback
        json_patterns = [
            r'```json\s*([\s\S]*?)\s*```',
            r'```\s*([\s\S]*?)\s*```',
            r'\{[\s\S]*\}',
        ]

        for pattern in json_patterns:
            matches = re.findall(pattern, text)
            for match in matches:
                try:
                    return json.loads(match)
                except json.JSONDecodeError:
                    continue

        # Fall back to parsing the whole text
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return None

    def is_available(self) -> bool:
        """Check if Ollama is running and accessible."""
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            return response.status_code == 200
        except Exception:
            return False

    def list_models(self) -> List[str]:
        """List available models."""
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=10)
            response.raise_for_status()
            data = response.json()
            return [m["name"] for m in data.get("models", [])]
        except Exception as e:
            logger.error(f"Failed to list models: {e}")
            return []
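

# A minimal usage sketch (not part of the original module): it assumes an
# Ollama server is reachable at the default http://localhost:11434 and that
# the default model has already been pulled. The prompt text is illustrative.
if __name__ == "__main__":
    client = OllamaClient()

    if not client.is_available():
        print("Ollama is not reachable; is the server running?")
    else:
        print(f"Available models: {client.list_models()}")

        # JSON mode plus extract_json_from_text gives a robust parse even if
        # the model wraps its answer in extra prose or a code fence.
        result = client.generate(
            prompt="List three emerging technologies as JSON: "
                   '{"technologies": [...]}',
            format="json",
        )
        if result.success:
            parsed = client.extract_json_from_text(result.content)
            print(f"Parsed ({result.tokens_used} tokens): {parsed}")
        else:
            print(f"Generation failed: {result.error}")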