#!/usr/bin/env python3
"""
Script to evaluate query results using an LLM at different distance thresholds.
Measures result quality, relevance, and filtering effectiveness.
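
Run this file directly once the RAG API is up (see main() for the test
queries and thresholds).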
"""

import json
import sys
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional

import requests

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Make the application package importable when this file is run as a script
# (assumes the script sits one level below the project root).
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from app.config import settings
from app.services.llm_service import LLMService


class QueryEvaluator:
    """Evaluate query results using LLM at different distance thresholds."""
    
    def __init__(self, api_base_url: str = "http://localhost:8000", course_id: str = "default"):
        """Initialize evaluator.
        
        Args:
            api_base_url: Base URL of the RAG API
            course_id: ID of the course to evaluate
        """
        self.api_base_url = api_base_url
        self.course_id = course_id
        self.llm_service: Optional[LLMService] = None
        
        # Initialize LLM service if available
        if settings.OPENAI_API_KEY:
            try:
                self.llm_service = LLMService(
                    api_key=settings.OPENAI_API_KEY,
                    model=settings.OPENAI_MODEL,
                    temperature=0.3,  # Lower temp for evaluation
                    max_tokens=500,
                )
                logger.info("LLM service initialized for evaluation")
            except Exception as e:
                logger.warning(f"Failed to initialize LLM service: {e}")
        else:
            logger.warning("OPENAI_API_KEY not set - LLM evaluation unavailable")
    
    def query_api(self, query: str, top_k: int = 3, distance_threshold: Optional[float] = None) -> List[Dict[str, Any]]:
        """Query the RAG API.
        
        Args:
            query: Search query
            top_k: Maximum number of results to return
            distance_threshold: Optional distance cutoff; results further
                away than this are filtered out
        
        Returns:
            List of results with distance scores
        """
        url = f"{self.api_base_url}/courses/{self.course_id}/query"
        payload = {
            "query": query,
            "top_k": top_k,
        }
        if distance_threshold is not None:
            payload["distance_threshold"] = distance_threshold
        
        try:
            response = requests.post(url, json=payload, timeout=10)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"API query failed: {e}")
            return []
    
    def evaluate_result_relevance(self, query: str, result_content: str) -> Dict[str, Any]:
        """Evaluate single result relevance using LLM.
        
        Args:
            query: Original search query
            result_content: Retrieved document content
        
        Returns:
            Dict with relevance_score (0-10, or None on failure), explanation,
            and an evaluated flag
        """
        if not self.llm_service:
            logger.warning("LLM service not available, skipping evaluation")
            return {
                "relevance_score": None,
                "explanation": "LLM service not available",
                "evaluated": False
            }
        
        prompt = f"""Evaluate the relevance of the following document to the query.

QUERY: {query}

DOCUMENT: {result_content[:500]}  # First 500 chars

Rate relevance on 0-10 scale where:
- 0-2: Not relevant
- 3-4: Tangentially related
- 5-6: Moderately relevant
- 7-8: Highly relevant
- 9-10: Directly answers the query

Respond with exactly:
SCORE: [number]
EXPLANATION: [one sentence]"""
        
        try:
            # Assumes LLMService wraps an OpenAI client exposed as `.client`,
            # consistent with the OPENAI_* settings used in __init__.
            response = self.llm_service.client.chat.completions.create(
                model=self.llm_service.model,
                max_tokens=200,
                messages=[{"role": "user", "content": prompt}],
            )

            response_text = response.choices[0].message.content
            
            # Parse response
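            # Expected reply format (per the prompt above):
            #   SCORE: 8
            #   EXPLANATION: The document directly defines the queried term.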
            lines = response_text.strip().split('\n')
            score = None
            explanation = ""
            
            for line in lines:
                if line.startswith("SCORE:"):
                    try:
                        score = int(line.split(":", 1)[1].strip())
                    except (ValueError, IndexError):
                        pass
                elif line.startswith("EXPLANATION:"):
                    explanation = line.split(":", 1)[1].strip()
            
            return {
                "relevance_score": score,
                "explanation": explanation,
                "evaluated": True
            }
        except Exception as e:
            logger.error(f"LLM evaluation failed: {e}")
            return {
                "relevance_score": None,
                "explanation": str(e),
                "evaluated": False
            }
    
    def evaluate_threshold_impact(self, queries: List[str], thresholds: List[float], top_k: int = 5) -> Dict[str, Any]:
        """Evaluate impact of different distance thresholds.
        
        Args:
            queries: Test queries
            thresholds: Distance thresholds to test
            top_k: Results per query
        
        Returns:
            Dict with per-query, per-threshold result counts, distance stats,
            and LLM relevance evaluations
        """
        results = {
            "queries": queries,
            "thresholds": thresholds,
            "evaluations": []
        }
        
        for query in queries:
            logger.info(f"\nEvaluating query: {query}")
            query_results = {
                "query": query,
                "thresholds": {}
            }
            
            for threshold in thresholds:
                logger.info(f"  Threshold: {threshold}")
                
                # Get results from API
                docs = self.query_api(query, top_k=top_k, distance_threshold=threshold)
                
                threshold_eval = {
                    "threshold": threshold,
                    "result_count": len(docs),
                    "distance_stats": {},
                    "relevance_evaluations": []
                }
                
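                # Lower distance = closer match (assuming cosine/L2 distance),
                # so a tighter threshold should return fewer documents.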
                if docs:
                    # Analyze the distance distribution (ignore docs that are
                    # missing a distance score rather than counting them as 0)
                    distances = [d["distance"] for d in docs if d.get("distance") is not None]
                    if distances:
                        threshold_eval["distance_stats"] = {
                            "min": min(distances),
                            "max": max(distances),
                            "mean": sum(distances) / len(distances),
                        }
                    
                    # Evaluate relevance of the top-ranked result
                    top_result = docs[0]
                    relevance = self.evaluate_result_relevance(
                        query,
                        top_result.get("page_content", "")
                    )
                    threshold_eval["relevance_evaluations"].append({
                        "rank": 1,
                        "distance": top_result.get("distance"),
                        "evaluation": relevance
                    })

                    if relevance["evaluated"]:
                        logger.info(f"    Top result relevance: {relevance['relevance_score']}/10")
                
                query_results["thresholds"][str(threshold)] = threshold_eval
            
            results["evaluations"].append(query_results)
        
        return results
    
    def generate_report(self, evaluation_results: Dict[str, Any]) -> str:
        """Generate human-readable evaluation report.
        
        Args:
            evaluation_results: Results from evaluate_threshold_impact
        
        Returns:
            Formatted report
        """
        report = []
        report.append("=" * 80)
        report.append("DISTANCE THRESHOLD EVALUATION REPORT")
        report.append("=" * 80)
        
        for query_eval in evaluation_results["evaluations"]:
            report.append(f"\nQuery: {query_eval['query']}")
            report.append("-" * 80)
            
            for threshold_data in query_eval["thresholds"].values():
                threshold = threshold_data["threshold"]
                result_count = threshold_data["result_count"]
                
                report.append(f"\n  Threshold: {threshold}")
                report.append(f"    Results returned: {result_count}")
                
                if threshold_data["distance_stats"]:
                    stats = threshold_data["distance_stats"]
                    report.append(f"    Distance range: {stats['min']:.4f} - {stats['max']:.4f}")
                    report.append(f"    Average distance: {stats['mean']:.4f}")
                
                for rel_eval in threshold_data["relevance_evaluations"]:
                    if rel_eval["evaluation"]["evaluated"]:
                        score = rel_eval["evaluation"]["relevance_score"]
                        expl = rel_eval["evaluation"]["explanation"]
                        report.append(f"    Relevance score: {score}/10")
                        report.append(f"    Explanation: {expl}")
        
        report.append("\n" + "=" * 80)
        return "\n".join(report)


def main():
    """Run evaluation."""
    logger.info("Starting distance threshold evaluation...")
    
    # Test queries
    test_queries = [
        "What is machine learning?",
        "How does neural network training work?",
        "Explain artificial intelligence applications",
    ]
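    # Swap in queries representative of your indexed course content.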
    
    # Thresholds to test
    test_thresholds = [0.3, 0.5, 0.7, 0.9, 1.0]
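    # This sweep assumes distances fall roughly in [0, 1]; widen or shift it
    # if the vector store reports cosine distance on [0, 2] or another metric.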
    
    # Initialize evaluator
    evaluator = QueryEvaluator()
    
    # Check API is accessible
    try:
        response = requests.get(f"{evaluator.api_base_url}/courses", timeout=5)
        response.raise_for_status()
        logger.info("API is accessible")
    except Exception as e:
        logger.error(f"Cannot reach API at {evaluator.api_base_url}: {e}")
        logger.info("Make sure the API is running: uvicorn app.main:app --reload")
        sys.exit(1)
    
    # Run evaluation
    results = evaluator.evaluate_threshold_impact(test_queries, test_thresholds, top_k=5)
    
    # Generate and print report
    report = evaluator.generate_report(results)
    print(report)
    
    # Save results to file
    output_file = "evaluation_results.json"
    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)
    logger.info(f"Results saved to {output_file}")


if __name__ == "__main__":
    main()
