"""
|
|
Phase 1: Discovery Pipeline
|
|
|
|
Takes a capability gap, decomposes it into searches, queries multiple sources,
|
|
scores results, and returns ranked candidates.
|
|
"""
|
|
|
|
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional
import uuid

from ..config import TechScoutConfig, config as default_config
from ..extraction.llm_client import OllamaClient
from ..extraction.decomposer import QueryDecomposer, DecomposedQuery
from ..extraction.scorer import ResultScorer
from ..extraction.org_extractor import OrganizationExtractor
from ..search.web import WebSearcher
from ..search.base import SearchResult
from ..sources.sbir import SBIRSearcher
from ..sources.patents import PatentSearcher
from ..sources.contracts import ContractSearcher

logger = logging.getLogger(__name__)


@dataclass
class TechnologyCandidate:
    """A scored technology candidate."""
    id: str
    title: str
    organization: str
    description: str
    source_type: str
    source: str
    url: str
    score: float
    relevance_score: float
    trl_estimate: Optional[int]
    award_amount: Optional[float]
    published_date: Optional[str]
    award_id: Optional[str]
    patent_number: Optional[str]

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "title": self.title,
            "organization": self.organization,
            "description": self.description,
            "source_type": self.source_type,
            "source": self.source,
            "url": self.url,
            "score": self.score,
            "relevance_score": self.relevance_score,
            "trl_estimate": self.trl_estimate,
            "award_amount": self.award_amount,
            "published_date": self.published_date,
            "award_id": self.award_id,
            "patent_number": self.patent_number,
        }


@dataclass
class DiscoveryResult:
    """Result of Phase 1 discovery."""
    id: str
    capability_gap: str
    timestamp: str
    decomposition: Dict[str, Any]
    candidates: List[TechnologyCandidate]
    source_stats: Dict[str, int]
    total_results_found: int
    search_duration_seconds: float
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "capability_gap": self.capability_gap,
            "timestamp": self.timestamp,
            "decomposition": self.decomposition,
            "candidates": [c.to_dict() for c in self.candidates],
            "source_stats": self.source_stats,
            "total_results_found": self.total_results_found,
            "search_duration_seconds": self.search_duration_seconds,
            "success": self.success,
            "error": self.error,
        }

    def save(self, path: Path):
        """Save result to JSON file."""
        with open(path, "w") as f:
            json.dump(self.to_dict(), f, indent=2)


class DiscoveryPipeline:
    """
    Phase 1 Discovery Pipeline.

    1. Decompose capability gap into search queries
    2. Search multiple sources (SBIR, patents, contracts, web)
    3. Score and rank results
    4. Return top candidates
    """

    def __init__(
        self,
        config: Optional[TechScoutConfig] = None,
        model: str = "mistral-nemo:12b"
    ):
        self.config = config or default_config
        self.model = model

        # Initialize components
        self.llm_client = OllamaClient(
            base_url=self.config.ollama.base_url,
            default_model=model
        )
        self.decomposer = QueryDecomposer(self.llm_client, model)
        self.scorer = ResultScorer(self.llm_client, model)
        self.org_extractor = OrganizationExtractor(self.llm_client, model)

        # Initialize searchers
        self.web_searcher = WebSearcher()
        self.sbir_searcher = SBIRSearcher()
        self.patent_searcher = PatentSearcher()
        self.contract_searcher = ContractSearcher()

    def discover(
        self,
        capability_gap: str,
        max_results: int = 50,
        use_llm_scoring: bool = True,
        sources: Optional[List[str]] = None
    ) -> DiscoveryResult:
        """
        Run Phase 1 discovery.

        Args:
            capability_gap: Natural language description of capability need
            max_results: Maximum candidates to return
            use_llm_scoring: Use LLM for relevance scoring (slower but better)
            sources: Which sources to search (default: all)

        Returns:
            DiscoveryResult with ranked candidates
        """
        start_time = datetime.now()
        result_id = str(uuid.uuid4())[:8]

        sources = sources or ["sbir", "patents", "contracts", "web", "news"]

        logger.info(f"Starting discovery for: {capability_gap[:100]}...")

        # Step 1: Decompose the query
        logger.info("Decomposing capability gap into search queries...")
        decomposition = self.decomposer.decompose(capability_gap)

        if not decomposition.success:
            return DiscoveryResult(
                id=result_id,
                capability_gap=capability_gap,
                timestamp=datetime.now().isoformat(),
                decomposition={},
                candidates=[],
                source_stats={},
                total_results_found=0,
                search_duration_seconds=0,
                success=False,
                error=f"Query decomposition failed: {decomposition.error}"
            )

        logger.info(f"Generated {len(decomposition.search_queries)} search queries")

        # Step 2: Search all sources
        all_results: List[SearchResult] = []
        source_stats: Dict[str, int] = {}

        # SBIR/STTR
        if "sbir" in sources:
            logger.info("Searching SBIR/STTR awards...")
            for query in decomposition.sbir_queries[:3]:
                try:
                    results = self.sbir_searcher.search(query, max_results=15)
                    all_results.extend(results)
                except Exception as e:
                    logger.warning(f"SBIR search failed: {e}")
            source_stats["sbir"] = len([r for r in all_results if r.source_type == "sbir"])

        # Patents
        if "patents" in sources:
            logger.info("Searching patents...")
            for query in decomposition.patent_queries[:3]:
                try:
                    results = self.patent_searcher.search(query, max_results=15)
                    all_results.extend(results)
                except Exception as e:
                    logger.warning(f"Patent search failed: {e}")
            source_stats["patents"] = len([r for r in all_results if r.source_type == "patent"])

        # Federal contracts
        if "contracts" in sources:
            logger.info("Searching federal contracts...")
            for query in decomposition.search_queries[:2]:
                try:
                    results = self.contract_searcher.search_dod(query, max_results=10)
                    all_results.extend(results)
                except Exception as e:
                    logger.warning(f"Contract search failed: {e}")
            source_stats["contracts"] = len([r for r in all_results if r.source_type == "contract"])

        # Web search
        if "web" in sources:
            logger.info("Searching web...")
            for query in decomposition.search_queries[:4]:
                try:
                    results = self.web_searcher.search(query, max_results=10)
                    all_results.extend(results)
                except Exception as e:
                    logger.warning(f"Web search failed: {e}")
            source_stats["web"] = len([r for r in all_results if r.source_type == "web"])

        # News search
        if "news" in sources:
            logger.info("Searching defense news...")
            for query in decomposition.news_queries[:2]:
                try:
                    results = self.web_searcher.search(query, max_results=10, news_only=True)
                    all_results.extend(results)
                except Exception as e:
                    logger.warning(f"News search failed: {e}")
            source_stats["news"] = len([r for r in all_results if r.source_type == "news"])

        total_found = len(all_results)
        logger.info(f"Found {total_found} total results")

        # Step 3: Deduplicate by URL
        seen_urls = set()
        unique_results = []
        for result in all_results:
            if result.url and result.url not in seen_urls:
                seen_urls.add(result.url)
                unique_results.append(result)

        logger.info(f"After deduplication: {len(unique_results)} unique results")

        # Step 3.5: Extract organizations for web/news results that don't have them
        logger.info("Extracting organizations from web/news results...")
        results_needing_org = [
            (i, r) for i, r in enumerate(unique_results)
            if r.source_type in ("web", "news", "government", "academic") and not r.organization
        ]

        if results_needing_org:
            # Extract organizations using hybrid regex + LLM approach
            items_to_extract = [(r.title, r.snippet) for _, r in results_needing_org]
            extractions = self.org_extractor.extract_batch(items_to_extract, use_llm_fallback=True)

            # Update results with extracted organizations
            for (idx, result), extraction in zip(results_needing_org, extractions):
                if extraction.organization:
                    unique_results[idx].organization = extraction.organization
                    logger.debug(f"Extracted org '{extraction.organization}' from '{result.title[:50]}' (method: {extraction.method})")

            extracted_count = sum(1 for e in extractions if e.organization)
            logger.info(f"Extracted organizations for {extracted_count}/{len(results_needing_org)} web/news results")

        # Step 4: Score and rank
        logger.info("Scoring results...")
        scored_results = self.scorer.score_results(
            unique_results,
            capability_gap,
            target_trl=decomposition.target_trl_range,
            use_llm=use_llm_scoring
        )

        # Step 5: Convert to candidates
        candidates = []
        for i, result in enumerate(scored_results[:max_results]):
            candidates.append(TechnologyCandidate(
                id=f"{result_id}-{i}",
                title=result.title,
                organization=result.organization or "Unknown",
                description=result.snippet,
                source_type=result.source_type,
                source=result.source,
                url=result.url,
                score=result.final_score,
                relevance_score=result.relevance_score,
                trl_estimate=result.trl_estimate,
                award_amount=result.award_amount,
                published_date=result.published_date,
                award_id=result.award_id,
                patent_number=result.patent_number,
            ))

        duration = (datetime.now() - start_time).total_seconds()

        discovery_result = DiscoveryResult(
            id=result_id,
            capability_gap=capability_gap,
            timestamp=datetime.now().isoformat(),
            decomposition=decomposition.to_dict(),
            candidates=candidates,
            source_stats=source_stats,
            total_results_found=total_found,
            search_duration_seconds=duration,
            success=True
        )

        # Save result
        save_path = self.config.analyses_dir / f"discovery_{result_id}.json"
        discovery_result.save(save_path)
        logger.info(f"Saved discovery result to {save_path}")

        return discovery_result
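

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes this module is run as
# part of its package (e.g. `python -m <package>.discovery`) so the relative
# imports resolve, and that a local Ollama server is available for the
# LLM-backed decomposition and scoring. The capability-gap string below is a
# hypothetical example, not part of the pipeline itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    pipeline = DiscoveryPipeline()
    result = pipeline.discover(
        "Lightweight counter-UAS detection for dismounted units",  # hypothetical gap
        max_results=20,
        use_llm_scoring=False,  # skip LLM scoring for a faster dry run
        sources=["sbir", "web"],
    )

    if result.success:
        for candidate in result.candidates[:10]:
            print(f"{candidate.score:.2f}  {candidate.organization}  {candidate.title}")
    else:
        print(f"Discovery failed: {result.error}")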