# tests/test_distance_threshold.py
"""Test distance threshold filtering for embeddings."""
import pytest
from fastapi.testclient import TestClient
from app.main import app
from app.services.query_service import QueryService
from langchain_redis import RedisVectorStore
from langchain_huggingface import HuggingFaceEmbeddings


class TestDistanceThresholdConfiguration:
    """Test distance threshold configuration and defaults."""

    def test_config_distance_threshold_default(self):
        """Default threshold exists, is a float, and lies in [0, 1]."""
        from app.config import settings
        assert hasattr(settings, 'EMBEDDING_DISTANCE_THRESHOLD')
        assert isinstance(settings.EMBEDDING_DISTANCE_THRESHOLD, float)
        assert 0.0 <= settings.EMBEDDING_DISTANCE_THRESHOLD <= 1.0

    def test_config_top_k_default(self):
        """Default top_k exists, is an int, and is positive."""
        from app.config import settings
        assert hasattr(settings, 'EMBEDDING_TOP_K')
        assert isinstance(settings.EMBEDDING_TOP_K, int)
        assert settings.EMBEDDING_TOP_K > 0

    def test_config_environment_override(self, monkeypatch):
        """Environment variables override the threshold and top_k settings.

        ``app.config`` is reloaded so that its ``settings`` object is rebuilt
        from the patched environment. The ``finally`` block undoes the env
        patches and reloads the module a second time. Without it, later tests
        reading ``app.config.settings`` would keep seeing 0.8 / 5.

        Modules that already bound the settings object by name
        (``from app.config import settings``) keep whichever object they
        imported, because ``importlib.reload`` does not rebind those names.
        """
        import importlib
        import app.config

        monkeypatch.setenv("EMBEDDING_DISTANCE_THRESHOLD", "0.8")
        monkeypatch.setenv("EMBEDDING_TOP_K", "5")
        try:
            # Rebuild settings from the patched environment.
            importlib.reload(app.config)
            assert app.config.settings.EMBEDDING_DISTANCE_THRESHOLD == 0.8
            assert app.config.settings.EMBEDDING_TOP_K == 5
        finally:
            # Restore the environment first, then rebuild settings from it.
            # A second undo by pytest's own teardown is a no-op.
            monkeypatch.undo()
            importlib.reload(app.config)


class TestQueryRequestModel:
    """Check that QueryRequest carries an optional distance_threshold field."""

    def test_query_request_with_distance_threshold(self):
        """Explicitly supplied query, top_k and distance_threshold round-trip."""
        from app.main import QueryRequest

        payload = {"query": "test query", "top_k": 5, "distance_threshold": 0.6}
        req = QueryRequest(**payload)
        # Dict order is preserved, so fields are checked query -> top_k -> threshold.
        for field_name, expected in payload.items():
            assert getattr(req, field_name) == expected

    def test_query_request_distance_threshold_optional(self):
        """Only the query is required; top_k and distance_threshold default to None."""
        from app.main import QueryRequest

        req = QueryRequest(query="test query")
        assert req.query == "test query"
        for field_name in ("top_k", "distance_threshold"):
            assert getattr(req, field_name) is None


class TestDocumentResponseModel:
    """Check that DocumentResponse exposes an optional distance field."""

    def test_document_response_with_distance(self):
        """Content, metadata and distance are stored exactly as given."""
        from app.main import DocumentResponse

        fields = {
            "page_content": "sample content",
            "metadata": {"source": "file.txt"},
            "distance": 0.35,
        }
        doc = DocumentResponse(**fields)
        for field_name, expected in fields.items():
            assert getattr(doc, field_name) == expected

    def test_document_response_distance_optional(self):
        """Leaving out distance yields None."""
        from app.main import DocumentResponse

        doc = DocumentResponse(
            page_content="sample content",
            metadata={"source": "file.txt"},
        )
        assert doc.distance is None


class TestQueryServiceDistanceFiltering:
    """Exercise QueryService.search with and without a distance threshold."""

    def test_search_without_threshold(self, mock_vectorstore_with_scores):
        """With no threshold, all three hits come back carrying distance and content."""
        service = QueryService(vectorstore=mock_vectorstore_with_scores)
        hits = service.search("test", top_k=3)

        assert len(hits) == 3
        for required_key in ("distance", "page_content"):
            assert all(required_key in hit for hit in hits)

    def test_search_with_high_threshold(self, mock_vectorstore_with_scores):
        """A permissive threshold of 0.9 keeps every one of the three hits."""
        service = QueryService(vectorstore=mock_vectorstore_with_scores)
        hits = service.search("test", top_k=3, distance_threshold=0.9)

        # 0.9 exceeds the distances of the fixture's first three documents.
        assert len(hits) == 3

    def test_search_with_low_threshold(self, mock_vectorstore_with_scores):
        """A strict threshold of 0.3 never returns a hit farther than 0.3."""
        service = QueryService(vectorstore=mock_vectorstore_with_scores)
        hits = service.search("test", top_k=3, distance_threshold=0.3)

        assert len(hits) <= 3
        assert all(hit["distance"] <= 0.3 for hit in hits)

    def test_search_distance_included_in_result(self, mock_vectorstore_with_scores):
        """Every returned hit reports a float distance within [0, 1]."""
        service = QueryService(vectorstore=mock_vectorstore_with_scores)
        hits = service.search("test", top_k=3, distance_threshold=0.8)

        for hit in hits:
            assert "distance" in hit
            distance = hit["distance"]
            assert isinstance(distance, float)
            assert 0.0 <= distance <= 1.0


class TestQueryAPIWithThreshold:
    """Test query API endpoint with distance threshold."""

    @staticmethod
    def _results_or_skip(response):
        """Return the JSON list body of a successful query response.

        The previous pattern (``if response.status_code == 200: ...``) let a
        test pass silently whatever the endpoint returned. Instead:

        * 422 fails the test, because it means request validation rejected
          the body, which is exactly the contract these tests exercise.
        * Any other non-200 status skips the test, so environment problems
          are reported as skipped rather than as passed.
        """
        status = response.status_code
        if status == 422:
            pytest.fail(f"Query request rejected by validation: {response.text}")
        if status != 200:
            pytest.skip(f"Query endpoint returned HTTP {status}; backing services may be unavailable")
        data = response.json()
        assert isinstance(data, list)
        return data

    def test_query_endpoint_default_threshold(self, client: TestClient, setup_test_data):
        """Test query endpoint uses default threshold from config."""
        course_id = "test_course"
        response = client.post(
            f"/courses/{course_id}/query",
            json={"query": "test query"}
        )

        for doc in self._results_or_skip(response):
            assert "distance" in doc

    def test_query_endpoint_custom_threshold(self, client: TestClient, setup_test_data):
        """Test query endpoint accepts custom distance threshold."""
        course_id = "test_course"
        response = client.post(
            f"/courses/{course_id}/query",
            json={
                "query": "test query",
                "top_k": 5,
                "distance_threshold": 0.7
            }
        )

        # Every returned document must respect the requested threshold.
        for doc in self._results_or_skip(response):
            assert "distance" in doc
            assert doc["distance"] <= 0.7

    def test_query_endpoint_custom_top_k(self, client: TestClient, setup_test_data):
        """Test query endpoint respects custom top_k."""
        course_id = "test_course"
        response = client.post(
            f"/courses/{course_id}/query",
            json={
                "query": "test query",
                "top_k": 2
            }
        )

        assert len(self._results_or_skip(response)) <= 2


class TestLLMEvaluationOfResults:
    """Use an LLM as a judge of how relevant a retrieved document is."""

    @pytest.mark.skipif(
        not __import__('app.config', fromlist=['settings']).settings.OPENAI_API_KEY,
        reason="OPENAI_API_KEY not set"
    )
    def test_evaluate_result_relevance_with_llm(self, llm_service_fixture):
        """Ask the LLM to grade one relevant document and check the reply format.

        Skips when the fixture could not build a service. Only the presence of
        the two requested section labels is asserted; the score value itself
        is not checked.
        """
        # NOTE(review): LLMService is not referenced below; this import only
        # confirms the module is importable.
        from app.services.llm_service import LLMService

        service = llm_service_fixture
        if service is None:
            pytest.skip("LLM service not available")

        query = "What is machine learning?"
        result_content = "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience."

        evaluation_prompt = f"""
        Query: {query}
        Retrieved Document: {result_content}
        
        Rate the relevance of this document to the query on a scale of 0-10.
        Also provide a brief explanation.
        
        Format your response as:
        RELEVANCE_SCORE: [0-10]
        EXPLANATION: [your explanation]
        """

        # NOTE(review): `client.messages.create(...)` and `.content[0].text`
        # follow the Anthropic Messages API shape, while the fixture builds the
        # service from OPENAI_* settings. Confirm the type of LLMService.client.
        reply = service.client.messages.create(
            model=service.model,
            max_tokens=200,
            messages=[{"role": "user", "content": evaluation_prompt}],
        )
        reply_text = reply.content[0].text

        for label in ("RELEVANCE_SCORE", "EXPLANATION"):
            assert label in reply_text


class TestDistanceMetricsAnalysis:
    """Test analysis and reporting of distance metrics."""

    def test_distance_distribution_analysis(self, mock_vectorstore_with_scores):
        """Under a permissive 1.0 threshold, every distance and their mean lie in [0, 1]."""
        qs = QueryService(vectorstore=mock_vectorstore_with_scores)
        results = qs.search("test", top_k=10, distance_threshold=1.0)

        distances = [r["distance"] for r in results]
        assert distances, "expected at least one result under a 1.0 threshold"
        assert min(distances) >= 0.0
        assert max(distances) <= 1.0

        avg_distance = sum(distances) / len(distances)
        assert 0.0 <= avg_distance <= 1.0

    def test_threshold_impact_on_result_count(self, mock_vectorstore_with_scores):
        """Raising the threshold must never reduce the number of results."""
        qs = QueryService(vectorstore=mock_vectorstore_with_scores)

        thresholds = [0.2, 0.4, 0.6, 0.8, 1.0]
        result_counts = [
            len(qs.search("test", top_k=10, distance_threshold=threshold))
            for threshold in thresholds
        ]

        # Assuming search filters the same candidate set by distance, every hit
        # kept at a lower threshold is also kept at any higher one. The counts
        # must therefore be non-decreasing at every step, regardless of how
        # much data the mock provides. Comparing only the first and last
        # counts would miss a dip in the middle.
        assert all(
            lower <= higher
            for lower, higher in zip(result_counts, result_counts[1:])
        ), f"result counts not monotone across thresholds {thresholds}: {result_counts}"


@pytest.fixture
def mock_vectorstore_with_scores():
    """Create a mock vector store that returns ``(Document, distance)`` pairs.

    Five documents are available, with ascending distances
    0.15, 0.35, 0.55, 0.75 and 0.95. Each call to
    ``similarity_search_with_score`` returns the first ``k`` of them, so
    ``top_k`` and thresholds above 0.55 are actually exercised. Previously
    the fixture always returned the first three, whatever ``k`` was.

    The unused ``mocker`` (pytest-mock) dependency has been dropped;
    ``unittest.mock`` is all this fixture needs.
    """
    from unittest.mock import MagicMock, Mock
    from langchain_core.documents import Document

    mock_docs = [
        (Document(page_content=f"Content {i}", metadata={"source": f"doc{i}.txt"}), distance)
        for i, distance in enumerate((0.15, 0.35, 0.55, 0.75, 0.95), start=1)
    ]

    def _search(query, k=3, **kwargs):
        # Default k=3 preserves the old fixed three-result behaviour for
        # callers that omit k. Slicing returns a fresh list on every call,
        # so callers cannot mutate the fixture's data.
        return mock_docs[:k]

    mock_vs = MagicMock()
    # Wrapped in Mock so tests can still inspect call_args / call_count.
    mock_vs.similarity_search_with_score = Mock(side_effect=_search)

    return mock_vs


@pytest.fixture
def setup_test_data(client: TestClient):
    """Placeholder hook for seeding the vector store before API tests.

    Depends on ``client`` so that fixture is created first. It currently
    loads no documents and does nothing on teardown; it only yields ``None``.
    """
    # TODO: load test documents before the yield and clean them up after it.
    yield


@pytest.fixture
def llm_service_fixture():
    """Provide a configured LLMService, or None when one cannot be built.

    Returns None when ``settings.OPENAI_API_KEY`` is falsy, or when importing
    or constructing the service raises. The exception is printed rather than
    propagated, so dependent tests can skip instead of erroring.
    """
    service = None
    try:
        from app.services.llm_service import LLMService
        from app.config import settings

        api_key = settings.OPENAI_API_KEY
        if api_key:
            service = LLMService(
                api_key=api_key,
                model=settings.OPENAI_MODEL,
                temperature=settings.LLM_TEMPERATURE,
                max_tokens=settings.LLM_MAX_TOKENS,
            )
    except Exception as exc:
        # Deliberate best-effort: report and fall through to returning None.
        print(f"LLM service not available: {exc}")

    return service
