# tests/test_llm_api.py
"""Integration tests for LLM chat API endpoint."""
from unittest.mock import MagicMock, patch

from app import main
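
# The request/response contract exercised below (inferred from the assertions
# in this file; the actual models live in app.main and may differ slightly):
#
#   POST /courses/{course_id}/llm/chat
#   body:     {"query": str, "top_k": int (optional), "guidelines": str (optional)}
#   200 body: {"answer": str, "sources": [{"content", "filename", "chunk_id"}]}
#   503 when the LLM service or vector store is not initialized,
#   404 when no relevant documents are found, 500 when synthesis fails.
#
# The `client`, `client_uninitialized`, and `sample_rag_results` fixtures are
# expected to come from conftest.py. A minimal sketch of what these tests
# assume about them (illustrative only, not the real conftest):
#
#   @pytest.fixture
#   def sample_rag_results():
#       return [
#           {"page_content": f"Document {i}",
#            "metadata": {"source": f"doc{i}.txt", "id": f"chunk_{i}"}}
#           for i in range(3)
#       ]
#
#   @pytest.fixture
#   def client():
#       return TestClient(main.app)  # app with the "default" course initialized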


class TestChatEndpointSuccess:
    """Test successful chat endpoint behavior."""

    def test_chat_basic_request(self, client, sample_rag_results):
        """Test basic chat request."""
        # patch() swaps app.main.llm_service for a MagicMock and restores
        # the original automatically when the block exits.
        with patch('app.main.llm_service') as mock_llm_svc:
            mock_llm_svc.synthesize = MagicMock(
                return_value=MagicMock(
                    answer="Machine learning is a subset of AI that enables systems to learn.",
                    sources=[
                        MagicMock(
                            content=doc["page_content"],
                            filename=doc["metadata"]["source"],
                            chunk_id=doc["metadata"].get("id"),
                        )
                        for doc in sample_rag_results
                    ]
                )
            )

            response = client.post(
                "/courses/default/llm/chat",
                json={"query": "What is machine learning?", "top_k": 3}
            )

            assert response.status_code == 200
            data = response.json()
            assert "answer" in data
            assert "sources" in data
            assert len(data["sources"]) == 3

    def test_chat_with_custom_guidelines(self, client, sample_rag_results):
        """Test chat request with custom guidelines."""
        with patch('app.main.llm_service') as mock_llm_svc:
            mock_llm_svc.synthesize = MagicMock(
                return_value=MagicMock(
                    answer="Test answer",
                    sources=[]
                )
            )

            response = client.post(
                "/courses/default/llm/chat",
                json={
                    "query": "Explain neural networks",
                    "top_k": 2,
                    "guidelines": "Use simple language"
                }
            )

            assert response.status_code == 200
            # Verify guidelines were passed through to the LLM service
            mock_llm_svc.synthesize.assert_called_once()
            call_kwargs = mock_llm_svc.synthesize.call_args[1]
            assert call_kwargs["guidelines"] == "Use simple language"

    def test_chat_response_structure(self, client, sample_rag_results):
        """Test chat response structure is correct."""
        with patch('app.main.llm_service') as mock_llm_svc:
            mock_synthesis = MagicMock()
            mock_synthesis.answer = "Test answer about AI"
            source1 = MagicMock()
            source1.content = "AI is intelligence"
            source1.filename = "ai.txt"
            source1.chunk_id = "chunk_001"
            mock_synthesis.sources = [source1]

            mock_llm_svc.synthesize = MagicMock(return_value=mock_synthesis)

            response = client.post(
                "/courses/default/llm/chat",
                json={"query": "What is AI?"}
            )

            assert response.status_code == 200
            data = response.json()

            # Check answer
            assert isinstance(data["answer"], str)
            assert "Test answer" in data["answer"]

            # Check sources
            assert isinstance(data["sources"], list)
            assert len(data["sources"]) == 1
            source = data["sources"][0]
            assert "content" in source
            assert "filename" in source
            assert "chunk_id" in source
            assert source["filename"] == "ai.txt"


class TestChatEndpointErrors:
    """Test chat endpoint error handling."""

    def test_chat_without_llm_service(self, client):
        """Test chat when LLM service is not initialized."""
        original_llm_service = main.llm_service
        main.llm_service = None
        
        try:
            response = client.post(
                "/courses/default/llm/chat",
                json={"query": "What is AI?"}
            )
            
            assert response.status_code == 503
            assert "LLM service not initialized" in response.json()["detail"]
        finally:
            main.llm_service = original_llm_service

    def test_chat_without_vectorstore(self, client_uninitialized):
        """Test chat when vectorstore is not initialized."""
        # This client has no vectorstore
        response = client_uninitialized.post(
            "/courses/default/llm/chat",
            json={"query": "What is AI?"}
        )
        
        assert response.status_code == 503
        assert "Vector store not initialized" in response.json()["detail"]

    def test_chat_no_relevant_documents(self, client):
        """Test chat when no relevant documents are found."""
        # Mock query service to return empty results
        original_query_service = main.query_services.get("default")
        mock_query_service = MagicMock()
        mock_query_service.vectorstore = MagicMock()
        mock_query_service.search = MagicMock(return_value=[])
        main.query_services["default"] = mock_query_service
        
        original_llm_service = main.llm_service
        main.llm_service = MagicMock()
        
        try:
            response = client.post(
                "/courses/default/llm/chat",
                json={"query": "Obscure topic XYZ"}
            )
            
            assert response.status_code == 404
            assert "No relevant documents found" in response.json()["detail"]
        finally:
            main.llm_service = original_llm_service
            if original_query_service:
                main.query_services["default"] = original_query_service
            else:
                main.query_services.pop("default", None)

    def test_chat_llm_synthesis_failure(self, client):
        """Test chat when LLM synthesis fails."""
        original_llm_service = main.llm_service
        mock_llm_service = MagicMock()
        mock_llm_service.synthesize = MagicMock(
            side_effect=Exception("OpenAI API error")
        )
        main.llm_service = mock_llm_service
        
        try:
            response = client.post(
                "/courses/default/llm/chat",
                json={"query": "What is machine learning?"}
            )
            
            assert response.status_code == 500
            assert "Chat failed" in response.json()["detail"]
        finally:
            main.llm_service = original_llm_service

    def test_chat_invalid_course(self, client):
        """Test chat with non-existent course."""
        # Try to chat with a course that has no vectorstore
        original_llm_service = main.llm_service
        main.llm_service = MagicMock()
        
        try:
            response = client.post(
                "/courses/nonexistent_course/llm/chat",
                json={"query": "What is AI?"}
            )
            
            # Should fail because vectorstore not initialized
            assert response.status_code == 503
        finally:
            main.llm_service = original_llm_service


class TestChatEndpointValidation:
    """Test chat endpoint input validation."""

    def test_chat_empty_query(self, client):
        """Test chat with empty query."""
        original_llm_service = main.llm_service
        main.llm_service = MagicMock()
        
        try:
            response = client.post(
                "/courses/default/llm/chat",
                json={"query": ""}
            )
            
            # Whether an empty query is rejected depends on the request model:
            # a min-length constraint yields 422, a handler-level check 400,
            # and with no constraint the request is processed normally (200).
            assert response.status_code in [200, 400, 422]
        finally:
            main.llm_service = original_llm_service

    def test_chat_missing_query(self, client):
        """Test chat without query parameter."""
        response = client.post(
            "/courses/default/llm/chat",
            json={"top_k": 3}
        )
        
        # A missing required `query` field fails request validation with 422
        assert response.status_code == 422

    def test_chat_invalid_top_k(self, client):
        """Test chat with invalid top_k."""
        original_llm_service = main.llm_service
        main.llm_service = MagicMock()
        
        try:
            response = client.post(
                "/courses/default/llm/chat",
                json={"query": "What is AI?", "top_k": -1}
            )
            
            # If the request model constrains top_k (e.g. gt=0), a negative
            # value is rejected with 422; otherwise the endpoint may handle
            # it gracefully (200) or reject it at the handler level (400).
            assert response.status_code in [200, 400, 422]
        finally:
            main.llm_service = original_llm_service

    def test_chat_with_default_top_k(self, client):
        """Test that default top_k is used."""
        original_llm_service = main.llm_service
        mock_llm_service = MagicMock()
        mock_llm_service.synthesize = MagicMock(
            return_value=MagicMock(answer="Test", sources=[])
        )
        main.llm_service = mock_llm_service
        
        try:
            response = client.post(
                "/courses/default/llm/chat",
                json={"query": "What is AI?"}
            )
            
            # When top_k is omitted, the endpoint should fall back to its
            # default; verify the LLM service was still invoked.
            if response.status_code == 200:
                call_args = mock_llm_service.synthesize.call_args
                assert call_args is not None
        finally:
            main.llm_service = original_llm_service


class TestChatEndpointIntegration:
    """Test full chat endpoint integration."""

    def test_chat_rag_to_llm_flow(self, client, sample_rag_results):
        """Test complete flow from RAG to LLM."""
        original_llm_service = main.llm_service
        original_query_service = main.query_services.get("default")
        
        # Mock query service to return sample results
        mock_query_service = MagicMock()
        mock_query_service.vectorstore = MagicMock()
        mock_query_service.search = MagicMock(return_value=sample_rag_results)
        main.query_services["default"] = mock_query_service
        
        # Mock LLM service
        mock_llm_service = MagicMock()
        mock_synthesis = MagicMock()
        mock_synthesis.answer = "Machine learning enables systems to learn from data."
        mock_synthesis.sources = [
            MagicMock(
                content=doc["page_content"],
                filename=doc["metadata"]["source"],
                chunk_id=doc["metadata"].get("id"),
            )
            for doc in sample_rag_results
        ]
        mock_llm_service.synthesize = MagicMock(return_value=mock_synthesis)
        main.llm_service = mock_llm_service
        
        try:
            response = client.post(
                "/courses/default/llm/chat",
                json={"query": "What is machine learning?", "top_k": 3}
            )
            
            assert response.status_code == 200
            data = response.json()
            
            # Verify RAG was called
            mock_query_service.search.assert_called_once()
            search_args = mock_query_service.search.call_args
            assert search_args[1]["query"] == "What is machine learning?"
            assert search_args[1]["top_k"] == 3
            
            # Verify LLM was called with results
            mock_llm_service.synthesize.assert_called_once()
            synthesis_args = mock_llm_service.synthesize.call_args
            assert synthesis_args[1]["query"] == "What is machine learning?"
            assert len(synthesis_args[1]["context_docs"]) == 3
            
            # Verify response
            assert "Machine learning enables systems" in data["answer"]
            assert len(data["sources"]) == 3
        finally:
            main.llm_service = original_llm_service
            if original_query_service:
                main.query_services["default"] = original_query_service
            else:
                main.query_services.pop("default", None)

    def test_chat_sources_include_metadata(self, client, sample_rag_results):
        """Test that chat response includes proper source metadata."""
        original_llm_service = main.llm_service
        mock_llm_service = MagicMock()
        
        sources = [
            MagicMock(
                content=doc["page_content"],
                filename=doc["metadata"]["source"],
                chunk_id=doc["metadata"].get("id"),
            )
            for doc in sample_rag_results
        ]
        
        mock_synthesis = MagicMock()
        mock_synthesis.answer = "Answer text"
        mock_synthesis.sources = sources
        mock_llm_service.synthesize = MagicMock(return_value=mock_synthesis)
        main.llm_service = mock_llm_service
        
        try:
            response = client.post(
                "/courses/default/llm/chat",
                json={"query": "Test?"}
            )
            
            assert response.status_code == 200
            data = response.json()
            
            for i, source in enumerate(data["sources"]):
                assert source["content"] == sample_rag_results[i]["page_content"]
                assert source["filename"] == sample_rag_results[i]["metadata"]["source"]
                assert source["chunk_id"] == sample_rag_results[i]["metadata"].get("id")
        finally:
            main.llm_service = original_llm_service
