# tests/test_distance_threshold_integration.py
"""Integration tests for distance threshold feature with real Redis."""

import pytest
import json
from fastapi.testclient import TestClient
from app.main import app
from app.config import settings
from app.main_helpers import chunk_text, save_registry, get_registry_path, get_index_name
import os
from datetime import datetime


@pytest.fixture
def test_course_id():
    """Provide the course identifier that the tests in this module query against."""
    course_id = "test_threshold_course"
    return course_id


@pytest.fixture
def test_documents():
    """Provide the sample corpus as a mapping of filename to document text.

    Each document is five lines long. Continuation lines carry an eight-space
    indent, and two lines end with a trailing space; both are part of the
    text that gets ingested, so they are spelled out explicitly here instead
    of relying on the layout of a triple-quoted literal.
    """
    # Every line after the first starts on a new line with eight leading spaces.
    line_break = "\n" + " " * 8

    ml_basics = line_break.join([
        "Machine Learning is a subset of artificial intelligence.",
        "It enables systems to learn and improve from experience.",
        "Machine learning algorithms build mathematical models based on sample data.",
        "The field of machine learning is concerned with the question of how to ",
        "construct computer programs that automatically improve with experience.",
    ])

    deep_learning = line_break.join([
        "Deep learning is a subset of machine learning.",
        "It is based on artificial neural networks with multiple layers.",
        "Deep neural networks can learn hierarchical representations of data.",
        "Deep learning has been successful in many domains including computer vision,",
        "natural language processing, and speech recognition.",
    ])

    nlp_guide = line_break.join([
        "Natural Language Processing is a field of AI.",
        "It focuses on enabling computers to understand and process human language.",
        "NLP techniques include tokenization, parsing, and semantic analysis.",
        "Applications of NLP include machine translation, sentiment analysis, ",
        "and question answering systems.",
    ])

    return {
        "ml_basics.txt": ml_basics,
        "deep_learning.txt": deep_learning,
        "nlp_guide.txt": nlp_guide,
    }


@pytest.fixture
def setup_test_course(client: TestClient, test_course_id, test_documents):
    """Stage the sample documents, ingest them into the test course, and yield its id.

    Each document is written under ``settings.DATA_DIR`` and then posted to
    ``/courses/{course_id}/asset/add``. The staged files are removed again on
    teardown.

    Setup and teardown share one ``try``/``finally`` so the staged files are
    removed even when a write or an ingestion request fails before the
    ``yield`` is reached; code placed after a bare ``yield`` would be skipped
    in that case.

    Yields:
        The course identifier supplied by the ``test_course_id`` fixture.
    """
    # Create data directory if needed
    os.makedirs(settings.DATA_DIR, exist_ok=True)

    staged_paths = [
        os.path.join(settings.DATA_DIR, filename) for filename in test_documents
    ]

    try:
        # Write the documents to disk with an explicit encoding so the bytes
        # do not depend on the platform's default locale.
        for filepath, content in zip(staged_paths, test_documents.values()):
            with open(filepath, "w", encoding="utf-8") as f:
                f.write(content)

        # Ingest documents
        for filename, content in test_documents.items():
            response = client.post(
                f"/courses/{test_course_id}/asset/add",
                files={"file": (filename, content)},
            )
            assert response.status_code == 200, f"Failed to upload {filename}: {response.text}"

        yield test_course_id
    finally:
        # Cleanup: a file may never have been written (setup failed early) or
        # may already be gone, so a missing file is not an error.
        for filepath in staged_paths:
            try:
                os.remove(filepath)
            except FileNotFoundError:
                pass


class TestDistanceThresholdIntegration:
    """Integration tests with real Redis vectorstore.

    Every test posts to ``/courses/{course_id}/query`` for the course prepared
    by the ``setup_test_course`` fixture and inspects the JSON list returned.
    """

    def test_query_returns_distance_scores(self, client: TestClient, setup_test_course):
        """Test that query results include distance scores."""
        course_id = setup_test_course

        response = client.post(
            f"/courses/{course_id}/query",
            json={"query": "machine learning", "top_k": 3}
        )

        assert response.status_code == 200
        results = response.json()
        assert len(results) > 0

        # Verify distance field
        for result in results:
            assert "distance" in result
            assert isinstance(result["distance"], (int, float))
            assert 0.0 <= result["distance"] <= 1.0
            assert "page_content" in result
            assert "metadata" in result

    def test_threshold_filters_results(self, client: TestClient, setup_test_course):
        """Test that distance threshold filters results correctly."""
        course_id = setup_test_course

        # Get all results with the most permissive threshold
        all_response = client.post(
            f"/courses/{course_id}/query",
            json={"query": "learning", "top_k": 10, "distance_threshold": 1.0}
        )
        # Check the status before decoding so an error payload fails here with
        # a clear message instead of later with an obscure TypeError.
        assert all_response.status_code == 200, all_response.text
        all_results = all_response.json()

        # Get results with tight threshold
        filtered_response = client.post(
            f"/courses/{course_id}/query",
            json={"query": "learning", "top_k": 10, "distance_threshold": 0.3}
        )
        assert filtered_response.status_code == 200, filtered_response.text
        filtered_results = filtered_response.json()

        # A tighter threshold must never return more results
        assert len(filtered_results) <= len(all_results)

        # All filtered results should be below threshold
        for result in filtered_results:
            assert result["distance"] <= 0.3

    def test_threshold_zero_returns_no_results(self, client: TestClient, setup_test_course):
        """Test that threshold 0 returns no results."""
        course_id = setup_test_course

        response = client.post(
            f"/courses/{course_id}/query",
            json={"query": "learning", "top_k": 10, "distance_threshold": 0.0}
        )

        assert response.status_code == 200
        results = response.json()
        assert len(results) == 0  # Nothing matches distance == 0

    def test_threshold_one_returns_all_results(self, client: TestClient, setup_test_course):
        """Test that threshold 1.0 returns maximum results."""
        course_id = setup_test_course

        response = client.post(
            f"/courses/{course_id}/query",
            json={"query": "learning", "top_k": 10, "distance_threshold": 1.0}
        )

        assert response.status_code == 200
        results = response.json()
        assert len(results) > 0

    def test_config_defaults_applied(self, client: TestClient, setup_test_course, monkeypatch):
        """Test that config defaults are applied when not in request.

        The environment override is confined to a ``monkeypatch.context()``
        block and ``app.config`` is reloaded again afterwards. Without that
        second reload the module would keep the 0.4 threshold after the
        environment variable had been restored, leaking into later tests.
        """
        import importlib
        import app.config as config_module

        course_id = setup_test_course

        try:
            with monkeypatch.context() as scoped_env:
                # Set tight threshold in config
                scoped_env.setenv("EMBEDDING_DISTANCE_THRESHOLD", "0.4")

                # Reload config so it is re-evaluated under the override.
                # NOTE(review): modules that bound ``settings`` via
                # ``from app.config import settings`` before this reload may
                # still hold the previous object - confirm the endpoint
                # actually observes the reloaded value.
                importlib.reload(config_module)

                # Query without specifying threshold
                response = client.post(
                    f"/courses/{course_id}/query",
                    json={"query": "learning", "top_k": 10}
                )
        finally:
            # The override has been undone by now; reload once more so the
            # module no longer carries the test-only threshold.
            importlib.reload(config_module)

        assert response.status_code == 200
        results = response.json()

        # All results should respect the config threshold
        for result in results:
            assert result["distance"] <= 0.4

    def test_query_with_all_parameters(self, client: TestClient, setup_test_course):
        """Test query with all parameters specified."""
        course_id = setup_test_course

        response = client.post(
            f"/courses/{course_id}/query",
            json={
                "query": "neural networks",
                "top_k": 5,
                "distance_threshold": 0.6
            }
        )

        assert response.status_code == 200
        results = response.json()

        # Verify constraints
        assert len(results) <= 5
        for result in results:
            # Check the key exists before indexing it, so a missing field
            # fails as an assertion rather than a KeyError.
            assert "distance" in result
            assert result["distance"] <= 0.6
            assert "page_content" in result
            assert "metadata" in result

    def test_distance_ordering(self, client: TestClient, setup_test_course):
        """Test that results are ordered by distance (best first)."""
        course_id = setup_test_course

        response = client.post(
            f"/courses/{course_id}/query",
            json={"query": "neural", "top_k": 10, "distance_threshold": 1.0}
        )

        assert response.status_code == 200
        results = response.json()

        if len(results) > 1:
            # Check ordering: lower distance first
            distances = [r["distance"] for r in results]
            assert distances == sorted(distances), "Results should be ordered by distance (ascending)"

    def test_similar_queries_have_consistent_distances(self, client: TestClient, setup_test_course):
        """Test that similar queries both return results carrying distance scores."""
        course_id = setup_test_course

        # Query 1
        first_response = client.post(
            f"/courses/{course_id}/query",
            json={"query": "machine learning algorithms", "top_k": 3}
        )
        assert first_response.status_code == 200, first_response.text
        r1 = first_response.json()

        # Query 2 (similar)
        second_response = client.post(
            f"/courses/{course_id}/query",
            json={"query": "learning algorithm", "top_k": 3}
        )
        assert second_response.status_code == 200, second_response.text
        r2 = second_response.json()

        assert len(r1) > 0 and len(r2) > 0

        # Only the presence and type of the distance score are verified here;
        # the relative ranking of individual documents is not asserted.
        for result in r1 + r2:
            assert "distance" in result
            assert isinstance(result["distance"], (int, float))


class TestDistanceThresholdEdgeCases:
    """Test edge cases and error handling.

    Every test posts to ``/courses/{course_id}/query`` for the course prepared
    by the ``setup_test_course`` fixture.
    """

    def test_negative_threshold_handled(self, client: TestClient, setup_test_course):
        """Test handling of invalid negative threshold."""
        course_id = setup_test_course

        # NOTE(review): the intent appears to be that a negative threshold is
        # treated as 0, but only the status code is checked here - confirm the
        # expected result set and assert on it if it is part of the contract.
        response = client.post(
            f"/courses/{course_id}/query",
            json={"query": "learning", "distance_threshold": -0.5}
        )

        # Should still work
        assert response.status_code == 200

    def test_threshold_over_one(self, client: TestClient, setup_test_course):
        """Test threshold > 1.0 (all results pass)."""
        course_id = setup_test_course

        response = client.post(
            f"/courses/{course_id}/query",
            json={"query": "learning", "top_k": 10, "distance_threshold": 1.5}
        )

        assert response.status_code == 200
        results = response.json()
        assert len(results) > 0

    def test_empty_query_with_threshold(self, client: TestClient, setup_test_course):
        """Test empty query string."""
        course_id = setup_test_course

        response = client.post(
            f"/courses/{course_id}/query",
            json={"query": "", "distance_threshold": 0.5}
        )

        # Should handle gracefully: either answer the query or reject it as a
        # client error. 422 is included because it is FastAPI's default status
        # for a request body that fails validation.
        assert response.status_code in [200, 400, 422]

    def test_very_large_top_k(self, client: TestClient, setup_test_course):
        """Test with very large top_k."""
        course_id = setup_test_course
        requested_top_k = 1000

        response = client.post(
            f"/courses/{course_id}/query",
            json={"query": "learning", "top_k": requested_top_k}
        )

        assert response.status_code == 200
        results = response.json()
        # The previous check (len(results) >= 0) could never fail. Assert
        # something that can: the payload is a list and top_k is respected
        # as an upper bound on its size.
        assert isinstance(results, list)
        assert len(results) <= requested_top_k


class TestDistanceThresholdMetrics:
    """Aggregate checks on the distances and result counts returned by queries."""

    def test_distance_distribution_analysis(self, client: TestClient, setup_test_course):
        """Check min/mean/max of the returned distances for a permissive query."""
        course_id = setup_test_course
        payload = {"query": "learning", "top_k": 20, "distance_threshold": 1.0}

        response = client.post(f"/courses/{course_id}/query", json=payload)
        assert response.status_code == 200

        results = response.json()
        if len(results) <= 1:
            # Fewer than two results leave nothing to compare.
            return

        distances = [entry["distance"] for entry in results]
        lowest = min(distances)
        highest = max(distances)
        mean = sum(distances) / len(distances)

        # Bounds first, then the mean must sit between the extremes.
        assert 0.0 <= lowest <= highest <= 1.0
        assert lowest <= mean <= highest

        # The results are expected to differ in distance
        assert highest - lowest > 0

    def test_threshold_filtering_statistics(self, client: TestClient, setup_test_course):
        """Check that the result count never shrinks as the threshold is raised."""
        course_id = setup_test_course
        endpoint = f"/courses/{course_id}/query"

        thresholds = [0.2, 0.4, 0.6, 0.8, 1.0]
        threshold_results = {}

        for threshold in thresholds:
            response = client.post(
                endpoint,
                json={"query": "learning", "top_k": 20, "distance_threshold": threshold},
            )
            assert response.status_code == 200
            threshold_results[threshold] = len(response.json())

        # Compare each count with the count for the next-higher threshold.
        counts = [threshold_results[t] for t in thresholds]
        for lower_count, higher_count in zip(counts, counts[1:]):
            assert lower_count <= higher_count, \
                f"Higher threshold should return more results: {threshold_results}"
